In [1]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
from tqdm import tqdm

from moment.utils.experiment_utils import \
    get_dl4tsc_results, get_ts2vec_results, draw_cd_diagram

### Results from TS2Vec and DL4TSC on UCR and UEA datasets

In [2]:
# Baseline table returned by get_ts2vec_results for the UCR archive, indexed by
# dataset (columns in the preview: TS2Vec, T-Loss, TNC, TS-TCC, TST, DTW).
ucr_results_unsupervised = get_ts2vec_results(database="ucr")
ucr_results_unsupervised.head()

Unnamed: 0_level_0,TS2Vec,T-Loss,TNC,TS-TCC,TST,DTW
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Adiac,0.762,0.675,0.726,0.767,0.55,0.604
ArrowHead,0.857,0.766,0.703,0.737,0.771,0.703
Beef,0.767,0.667,0.733,0.6,0.5,0.633
BeetleFly,0.9,0.8,0.85,0.8,1.0,0.7
BirdChicken,0.8,0.85,0.75,0.65,0.65,0.75


In [12]:
# Same baseline table as above, but for the UEA archive.
uea_results_unsupervised = get_ts2vec_results(database="uea")
uea_results_unsupervised.head()

Unnamed: 0_level_0,TS2Vec,T-Loss,TNC,TS-TCC,TST,DTW
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
ArticularyWordRecognition,0.987,0.943,0.973,0.953,0.977,0.987
AtrialFibrillation,0.2,0.133,0.133,0.267,0.067,0.2
BasicMotions,0.975,1.0,0.975,1.0,0.975,0.975
CharacterTrajectories,0.995,0.993,0.967,0.985,0.975,0.989
Cricket,0.972,0.972,0.958,0.917,1.0,1.0


In [10]:
# Baseline table returned by get_dl4tsc_results for the UCR archive, indexed by
# dataset (columns in the preview: CNN, Encoder, FCN, MCDNN, MLP, ResNet, t-LeNet, TWIESN).
ucr_results_supervised = get_dl4tsc_results(database="ucr")
ucr_results_supervised.head()

Unnamed: 0_level_0,CNN,Encoder,FCN,MCDNN,MLP,ResNet,t-LeNet,TWIESN
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
ACSF1,0.334,0.444,0.898,0.226,0.558,0.916,0.1,0.592
Adiac,0.39335,0.318159,0.841432,0.62046,0.391304,0.833248,0.022506,0.427621
AllGestureWiimoteX,0.411143,0.475143,0.713429,0.261429,0.476571,0.740571,0.1,0.522
AllGestureWiimoteY,0.478857,0.509429,0.784286,0.419714,0.570571,0.793714,0.1,0.600286
AllGestureWiimoteZ,0.375143,0.396,0.692,0.287143,0.439143,0.725714,0.1,0.516286


In [5]:
# DL4TSC baseline table requested for the UEA archive.
# NOTE(review): the preview lists names that also appear in the UCR tables
# (Adiac, ArrowHead, Beef, BeetleFly) plus '50words' -- confirm that
# database="uea" really loads the intended archive.
uea_results_supervised = get_dl4tsc_results(database="uea")
uea_results_supervised.head()

Unnamed: 0_level_0,CNN,Encoder,FCN,MCDNN,MCNN,MLP,ResNet,t-LeNet,TWIESN
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
50words,0.620879,0.723297,0.627473,0.589451,0.21956,0.684396,0.73956,0.125275,0.496044
Adiac,0.379028,0.484143,0.84399,0.610486,0.021995,0.396675,0.8289,0.02046,0.416368
ArrowHead,0.722857,0.804,0.842857,0.684571,0.339429,0.778286,0.844571,0.302857,0.658857
Beef,0.763333,0.643333,0.696667,0.563333,0.2,0.72,0.753333,0.2,0.536667
BeetleFly,0.89,0.745,0.86,0.58,0.5,0.87,0.85,0.5,0.73


In [6]:
# Join the two UCR dataframes on the shared 'Dataset' key. DataFrame.merge
# defaults to an inner join, so only datasets present in both tables are kept.
ucr_results = ucr_results_unsupervised.merge(ucr_results_supervised, on='Dataset')
ucr_results.head()

Unnamed: 0_level_0,TS2Vec,T-Loss,TNC,TS-TCC,TST,DTW,CNN,Encoder,FCN,MCDNN,MLP,ResNet,t-LeNet,TWIESN
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
Adiac,0.762,0.675,0.726,0.767,0.55,0.604,0.39335,0.318159,0.841432,0.62046,0.391304,0.833248,0.022506,0.427621
ArrowHead,0.857,0.766,0.703,0.737,0.771,0.703,0.716571,0.629714,0.843429,0.677714,0.784,0.837714,0.302857,0.689143
Beef,0.767,0.667,0.733,0.6,0.5,0.633,0.766667,0.706667,0.68,0.506667,0.713333,0.753333,0.2,0.526667
BeetleFly,0.9,0.8,0.85,0.8,1.0,0.7,0.9,0.62,0.91,0.63,0.88,0.85,0.5,0.79
BirdChicken,0.8,0.85,0.75,0.65,0.65,0.75,0.71,0.51,0.94,0.54,0.74,0.88,0.5,0.62


In [6]:
# Name of the experiment whose pickled results are analysed in this notebook.
experiment_name = "unsupervised_representation_learning"

# Root directory holding one sub-directory per experiment. It defaults to the
# scratch location used so far, and can be redirected on another machine by
# setting the MOMENT_RESULTS_DIR environment variable before starting the kernel.
results_root = os.environ.get(
    "MOMENT_RESULTS_DIR", "/home/extra_scratch/mgoswami/moment_results/")

results_path = os.path.join(results_root, experiment_name)
print(f"Results path: {results_path}")

Results path: /home/extra_scratch/mgoswami/moment_results/unsupervised_representation_learning


In [7]:
# Every entry of the results directory whose name contains 'results'.
dataset_with_results = [i for i in os.listdir(results_path) if 'results' in i]

# Accuracies keyed by the dataset name parsed from each file name.
train_accuracy = {}
test_accuracy = {}

for dataset in tqdm(dataset_with_results, total=len(dataset_with_results)):
    # Second '_'-separated token of the file name, minus its last four characters.
    # NOTE(review): presumably files are named '<prefix>_<Dataset>.pkl'; a dataset
    # name that itself contains '_' would be cut short here -- confirm.
    dataset_name = dataset.split("_")[1][:-4]
    full_path = os.path.join(results_path, dataset)
    # NOTE(review): unpickling can execute code stored in the file; only point
    # results_path at trusted experiment output.
    with open(full_path, "rb") as f:
        r = pkl.load(f)
    
    train_accuracy[dataset_name] = r.train_accuracy
    test_accuracy[dataset_name] = r.test_accuracy

  0%|          | 0/29 [00:00<?, ?it/s]

100%|██████████| 29/29 [04:21<00:00,  9.03s/it]


In [8]:
import pandas as pd

# One row per dataset (index named 'Dataset') and a single 'MOMENT' column
# holding the test accuracy collected above.
MOMENT_results = (
    pd.Series(test_accuracy, name='MOMENT')
    .rename_axis('Dataset')
    .to_frame()
)
MOMENT_results.head()

Unnamed: 0_level_0,MOMENT
Dataset,Unnamed: 1_level_1
ArticularyWordRecognition,0.99
AtrialFibrillation,0.2
BasicMotions,1.0
Cricket,0.986111
DuckDuckGeese,0.6


In [13]:
# Put MOMENT next to the TS2Vec-family baselines on the UEA datasets. merge()
# defaults to an inner join, so only datasets present in both frames remain.
results = MOMENT_results.merge(uea_results_unsupervised, on='Dataset')
# Disabled: additionally join the supervised DL4TSC baselines.
# results = results.merge(uea_results_supervised, on='Dataset')

In [22]:
# Show the merged comparison table.
results

Unnamed: 0_level_0,MOMENT,TS2Vec,T-Loss,TNC,TS-TCC,TST,DTW
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
ArticularyWordRecognition,0.99,0.987,0.943,0.973,0.953,0.977,0.987
AtrialFibrillation,0.2,0.2,0.133,0.133,0.267,0.067,0.2
BasicMotions,1.0,0.975,1.0,0.975,1.0,0.975,0.975
Cricket,0.986111,0.972,0.972,0.958,0.917,1.0,1.0
DuckDuckGeese,0.6,0.68,0.65,0.46,0.38,0.62,0.6
EigenWorms,0.80916,0.847,0.84,0.84,0.779,0.748,0.618
Epilepsy,0.992754,0.964,0.971,0.957,0.957,0.949,0.964
ERing,0.959259,0.874,0.133,0.852,0.904,0.874,0.133
EthanolConcentration,0.357414,0.308,0.205,0.297,0.285,0.262,0.323
FaceDetection,0.632804,0.501,0.513,0.536,0.544,0.534,0.529


In [26]:
# Print summary statistics of every method as a LaTeX table with three decimals.
print(results.describe().to_latex(float_format="%.3f"))

\begin{tabular}{lrrrrrrr}
\toprule
 & MOMENT & TS2Vec & T-Loss & TNC & TS-TCC & TST & DTW \\
\midrule
count & 29.000 & 29.000 & 29.000 & 29.000 & 29.000 & 29.000 & 28.000 \\
mean & 0.670 & 0.694 & 0.646 & 0.660 & 0.657 & 0.605 & 0.638 \\
std & 0.274 & 0.255 & 0.296 & 0.267 & 0.263 & 0.294 & 0.296 \\
min & 0.200 & 0.200 & 0.133 & 0.133 & 0.243 & 0.067 & 0.133 \\
25% & 0.411 & 0.501 & 0.451 & 0.469 & 0.460 & 0.408 & 0.456 \\
50% & 0.722 & 0.683 & 0.676 & 0.746 & 0.751 & 0.620 & 0.664 \\
75% & 0.909 & 0.928 & 0.905 & 0.911 & 0.904 & 0.850 & 0.914 \\
max & 1.000 & 0.989 & 1.000 & 0.979 & 1.000 & 1.000 & 1.000 \\
\bottomrule
\end{tabular}



In [15]:
# Write the per-dataset comparison table to a LaTeX file (three decimals).
results.to_latex("../../assets/results/zero_shot/multi_variate_classification.tex", multicolumn_format='c', float_format="%.3f")

In [None]:
# Read TimesNet and GPT4TS results
timesnet_gpt4ts_results = pd.read_csv('../../assets/results/finetuning/timesnet_gpt4ts_classification.csv')
# Drop the two W&B run columns; only the accuracy columns are used below.
timesnet_gpt4ts_results = timesnet_gpt4ts_results.drop(columns=['Wandb Run (TimesNet)', 'Wandb Run (GPT4TS)'])
# Split into two single-column frames indexed by 'Dataset', each column renamed
# to the bare model name so it lines up with the other baseline tables.
timesnet_results = timesnet_gpt4ts_results[['Dataset', 'TimesNet Test Accuracy']].set_index('Dataset')
timesnet_results.columns = ['TimesNet']
gpt4ts_results = timesnet_gpt4ts_results[['Dataset', 'GPT4TS Test Accuracy']].set_index('Dataset')
gpt4ts_results.columns = ['GPT4TS']

In [None]:
# Combine MOMENT with the UCR baselines and the TimesNet / GPT4TS results.
# NOTE(review): the following cell repeats these three merges before
# reordering the columns, so this cell looks redundant.
results = MOMENT_results.merge(ucr_results, on='Dataset')
results = results.merge(timesnet_results, on='Dataset')
results = results.merge(gpt4ts_results, on='Dataset')

In [None]:
# Rebuild the comparison table: MOMENT + UCR baselines + TimesNet + GPT4TS.
results = MOMENT_results.merge(ucr_results, on='Dataset')
results = results.merge(timesnet_results, on='Dataset')
results = results.merge(gpt4ts_results, on='Dataset')
# Fix the column order used in the exported table.
results = results[[
    'MOMENT', 'TimesNet', 'GPT4TS', 
    'TS2Vec', 'T-Loss', 'TNC', 'TS-TCC', 'TST', 
    'CNN', 'Encoder', 'FCN', 'MCDNN', 'MLP', 'ResNet', 't-LeNet', 'TWIESN',
    'DTW']]
# NOTE(review): index=False leaves the index out of the CSV. If 'Dataset' is
# the index here (as in the tables shown above), the rows in the file carry no
# dataset name -- confirm this is intended.
results.to_csv("../../assets/results/zero_shot/unsupervised_representation_learning.csv", index=False)
results.head()

In [16]:
# Rank the methods within each dataset (row-wise): rank 1 is the highest
# accuracy, and tied methods share the average of the ranks they span.
average_rank = results.rank(axis=1, method='average', ascending=False)
average_rank.head()

Unnamed: 0_level_0,MOMENT,TS2Vec,T-Loss,TNC,TS-TCC,TST,DTW
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
ArticularyWordRecognition,1.0,2.5,7.0,5.0,6.0,4.0,2.5
AtrialFibrillation,3.0,3.0,5.5,5.5,1.0,7.0,3.0
BasicMotions,2.0,5.5,2.0,5.5,2.0,5.5,5.5
Cricket,3.0,4.5,4.5,6.0,7.0,1.5,1.5
DuckDuckGeese,4.5,1.0,2.0,6.0,7.0,3.0,4.5


In [17]:
# Distribution of the per-dataset ranks; the 'mean' row is each method's average rank.
average_rank.describe()

Unnamed: 0,MOMENT,TS2Vec,T-Loss,TNC,TS-TCC,TST,DTW
count,29.0,29.0,29.0,29.0,29.0,29.0,28.0
mean,3.465517,2.862069,3.603448,4.362069,4.12069,5.068966,4.428571
std,1.986252,1.831747,2.106026,1.597412,2.05572,1.85031,1.642685
min,1.0,1.0,1.0,1.0,1.0,1.5,1.5
25%,2.0,1.0,2.0,3.0,2.0,3.5,3.0
50%,3.0,2.5,4.0,5.0,5.0,5.5,4.5
75%,5.0,3.5,5.0,5.5,6.0,7.0,5.625
max,7.0,7.0,7.0,7.0,7.0,7.0,7.0


In [19]:
# Compute the number of wins / losses for each method, with ties counted as half.
# Within a dataset, (ascending rank - 1) is the number of methods with lower
# accuracy plus 0.5 per tied method; the descending rank gives the mirror-image
# count of methods with higher accuracy. Both are summed over all datasets.
wins = (results.rank(axis=1, method='average', ascending=True) - 1).sum(axis=0)
losses = (results.rank(axis=1, method='average', ascending=False) - 1).sum(axis=0)

In [20]:
# Pairwise wins per method, summed over datasets.
wins

MOMENT    101.5
TS2Vec    119.0
T-Loss     97.5
TNC        75.5
TS-TCC     82.5
TST        55.0
DTW        72.0
dtype: float64

In [21]:
# Pairwise losses per method, summed over datasets.
losses

MOMENT     71.5
TS2Vec     54.0
T-Loss     75.5
TNC        97.5
TS-TCC     90.5
TST       118.0
DTW        96.0
dtype: float64

In [None]:
# Write the current comparison table to a LaTeX file (three decimals).
results.to_latex("../../assets/results/zero_shot/classification.tex", multicolumn_format='c', float_format="%.3f")

In [None]:
# Styling shared by the elements of the box plot.
boxprops = dict(linestyle='-', linewidth=1, color='k')
flierprops = dict(marker='o', markersize=12, markeredgecolor='darkgreen')
medianprops = dict(linestyle='-', linewidth=2, color='blue')
# Marker-style mean properties. Not used below, where the mean is drawn as a
# dashed line (meanline=True) styled by meanlineprops.
meanpointprops = dict(marker='D', markeredgecolor='black',
                      markerfacecolor='firebrick')
meanlineprops = dict(linestyle='--', linewidth=2, color='red')

model_names = results.columns.tolist()

# matplotlib's boxplot does not skip NaNs: a single missing accuracy would turn
# that method's box statistics into NaN and leave its box empty. Drop missing
# values per method so each box is computed from the datasets that do have a result.
accuracy_per_model = [results[name].dropna().to_numpy() for name in model_names]

fig = plt.figure(figsize=(10, 6))  # Specify the size of the figure
_ = plt.boxplot(accuracy_per_model,
                labels=model_names,
                meanline=True,
                showmeans=True,
                notch=True,
                bootstrap=10000,
                flierprops=flierprops,
                meanprops=meanlineprops,
                boxprops=boxprops,
                medianprops=medianprops,
                )

plt.grid(color='lightgray', linestyle='--', linewidth=0.5)
plt.ylabel("Accuracy", fontsize=14)
plt.xticks(rotation=45, ha='right', fontsize=14)
plt.yticks(fontsize=14)
plt.title("Accuracy on UCR datasets", fontsize=16)
plt.show()

In [None]:
# Reshape the wide table (one column per method) into long format, one row per
# (classifier, dataset) pair, which is what is passed to draw_cd_diagram below.
#
# reset_index() returns a new frame, so `results` keeps its 'Dataset' index.
# Resetting it in place would make this cell fail when re-run and would break
# the later cells that look datasets up through the index
# (e.g. results.index == 'Beef', or merging with left_index=True).
long_results = results.reset_index().melt(id_vars=['Dataset'], value_vars=model_names)
long_results.columns = ['dataset_name', 'classifier_name', 'accuracy']
long_results = long_results[['classifier_name', 'dataset_name', 'accuracy']]
long_results.head()

In [None]:
# Critical-difference diagram over all methods at significance level 0.05.
# The first return value gets its own name: binding it to `plt` would shadow
# the matplotlib.pyplot module that the later plotting cells rely on.
cd_plot, p_values, average_ranks = draw_cd_diagram(df_perf = long_results, alpha = 0.05, labels='Accuracy')
cd_plot.show()

In [None]:
### Results summary
# Methods included in the summary tables.
columns = ['MOMENT', 'TS2Vec', 'T-Loss', 'TNC', 'TS-TCC', 'TST', 
 'CNN', 'Encoder', 'FCN', 'MCDNN', 'MLP', 'ResNet', 't-LeNet', 'TWIESN', 
 'DTW']
# NOTE(review): fillna(0) makes a missing result count as 0 accuracy in these
# statistics, whereas the Mean/Median/Std. table built next skips missing
# values -- confirm which treatment is intended.
results[columns].fillna(0).describe()

In [None]:
# Mean, median and standard deviation of each method's accuracy over the
# datasets; missing values are skipped (pandas default).
#
# The statistics stay in pandas' default float precision. Casting them to
# float16 keeps only about three significant digits, which can change the last
# digit printed by the '%.3f' LaTeX export.
summary = results[columns].agg(['mean', 'median', 'std'])
summary.index = ['Mean', 'Median', 'Std.']
summary

In [None]:
# Print the Mean / Median / Std. table as LaTeX with three decimals.
print(summary.to_latex(float_format="%.3f"))

# Analysis

In [None]:
### Datasets with the worst performance in comparison to TS2Vec
# Ten largest (TS2Vec - MOMENT) accuracy gaps, largest gap first.
(results['TS2Vec'] - results['MOMENT']).sort_values(ascending=False)[:10]

In [None]:
import pandas as pd
# Dataset metadata table; its 'problem' column is used for the joins below.
# NOTE(review): this rebinds `summary`, which earlier held the
# Mean/Median/Std. table -- re-running the LaTeX export cell after this one
# would print this metadata table instead.
summary = pd.read_csv("../../assets/data/summaryUnivariate.csv")
summary.head()

In [None]:
# Disabled alternative: the 15 datasets with the lowest absolute MOMENT test accuracy.
# low_accuracy_datasets = sorted(test_accuracy, key=test_accuracy.get, reverse=False)[:15]
# Active choice: index labels of the ten largest (TS2Vec - MOMENT) gaps.
# These are dataset names only while `results` is indexed by dataset.
low_accuracy_datasets = (results['TS2Vec'] - results['MOMENT']).sort_values(ascending=False)[:10].index.tolist()
low_accuracy_datasets

# Analyze low accuracy datasets

In [None]:
# Metadata rows whose 'problem' is one of the selected low-accuracy datasets.
summary[summary["problem"].isin(low_accuracy_datasets)]

In [None]:
# Transform the two accuracy dictionaries into a dataframe with one row per
# dataset and one column per dictionary.
accuracies = pd.DataFrame(data=[test_accuracy, train_accuracy]).T
accuracies.columns = ["Test accuracy", "Train accuracy"]
# Attach the dataset metadata: the frame's index (dataset name) is matched
# against the 'problem' column of `summary`.
accuracies = accuracies.merge(summary, left_index=True, right_on="problem")

In [None]:
# Preview the accuracies joined with the dataset metadata.
accuracies.head()

In [None]:
# Test accuracy (x) against the number of training cases (y), one point per dataset.
plt.scatter(accuracies.loc[:, "Test accuracy"], accuracies.loc[:, "numTrainCases"])
plt.xlabel("Test accuracy", fontsize=16)
plt.ylabel("Number of training cases", fontsize=16)
plt.title("Test accuracy vs number of training cases", fontsize=18)
# The y-axis is capped at 1000: datasets with more training cases are not visible.
plt.ylim(0, 1000)

In [None]:
# Test accuracy (x) against the series length (y), one point per dataset.
plt.scatter(accuracies.loc[:, "Test accuracy"], accuracies.loc[:, "seriesLength"])
plt.xlabel("Test accuracy", fontsize=16)
plt.ylabel("Series Length", fontsize=16)
# The y-axis is capped at 600: longer series are not visible.
plt.ylim(0, 600)
plt.title("Test accuracy vs series length", fontsize=18)

### Fine-tuning

In [None]:
import torch

from moment.utils.config import Config
from moment.utils.utils import parse_config
from moment.data.dataloader import get_timeseries_dataloader
from moment.models.base import BaseModel
from moment.models.moment import MOMENT

In [None]:
def get_dataloaders(args):
    """Build the train, test and validation dataloaders for one dataset.

    The dataset is taken from ``args.full_file_path_and_name``. The function
    mutates ``args``: ``args.dataset_names`` is set to that path and
    ``args.data_split`` is set to each split in turn, ending as ``'val'``.

    Returns
    -------
    tuple
        ``(train_dataloader, test_dataloader, val_dataloader)``, each created
        by ``get_timeseries_dataloader(args=args)``.
    """
    args.dataset_names = args.full_file_path_and_name

    # Build the loaders in the fixed order train -> test -> val.
    loaders = []
    for split in ('train', 'test', 'val'):
        args.data_split = split
        loaders.append(get_timeseries_dataloader(args=args))

    train_dataloader, test_dataloader, val_dataloader = loaders
    return train_dataloader, test_dataloader, val_dataloader

def load_pretrained_moment(args,
                         pretraining_task_name: str = "pre-training"):
    """Instantiate MOMENT from ``args`` and load pre-trained weights into it.

    Parameters
    ----------
    args
        Configuration object. ``pretraining_run_name`` and
        ``pretraining_opt_steps`` select the checkpoint; ``task_name`` is
        overwritten with ``pretraining_task_name`` (side effect on the caller's
        object).
    pretraining_task_name : str
        Task name written to ``args.task_name`` before the model is built.

    Returns
    -------
    The MOMENT instance after ``load_state_dict`` was called with the
    checkpoint's ``"model_state_dict"`` entry.
    """
    # Set before constructing MOMENT, which receives `args` as its config.
    args.task_name = pretraining_task_name

    ckpt = BaseModel.load_pretrained_weights(
        run_name=args.pretraining_run_name,
        opt_steps=args.pretraining_opt_steps,
    )

    moment = MOMENT(configs=args)
    moment.load_state_dict(ckpt["model_state_dict"])
    return moment

def freeze_model_parameters(args, model):
    """Set ``requires_grad`` on the model's parameters and print the outcome.

    When ``args.finetuning_mode == 'linear-probing'`` each parameter is
    classified by substrings of its lower-cased name; the first matching rule
    below decides and unmatched parameters are frozen. For any other mode the
    flags are left untouched. In both cases the frozen / not-frozen status of
    every parameter is printed.

    Returns
    -------
    The same ``model`` object (modified in place).
    """
    # (substrings, trainable) pairs, checked in this order -- order matters,
    # e.g. a name containing both 'ln' and 'mlp' stays trainable.
    ordered_rules = (
        (('ln', 'norm', 'layer_norm'), True),
        (('wpe', 'position_embeddings', 'pos_drop'), True),
        (('mlp', 'densereludense'), False),
        (('attn', 'selfattention'), False),
        (('head',), True),
        (('patch_embedding',), True),
    )

    if args.finetuning_mode == 'linear-probing':
        for param_name, param in model.named_parameters():
            lowered = param_name.lower()
            param.requires_grad = next(
                (trainable for keys, trainable in ordered_rules
                 if any(key in lowered for key in keys)),
                False,
            )

    print("====== Frozen parameter status ======")
    for name, param in model.named_parameters():
        status = "Not frozen:" if param.requires_grad else "Frozen:"
        print(status, name)
    print("=====================================")
    return model

In [None]:
# Experiment config, layered on top of the repository-wide defaults.
config_path = "../../configs/classification/unsupervised_representation_learning.yaml"
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"
gpu_id = 0

# Load arguments and parse them
config = Config(config_file_path=config_path, 
            default_config_file_path=DEFAULT_CONFIG_PATH).parse()

# Use the selected GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): the fallback is the string 'cpu' rather than a torch.device;
# the embedding cells wrap args.device in torch.device(...), which accepts both.
config['device'] = torch.device('cuda:{}'.format(gpu_id)) if torch.cuda.is_available() else 'cpu'
args = parse_config(config)

# Overrides for this notebook session.
# NOTE(review): absolute dataset path -- presumably specific to one machine.
args.full_file_path_and_name = '/TimeseriesDatasets/classification/UCR/Beef/Beef_TEST.ts'
args.max_epoch = 20
args.batch_size = 16
args.init_lr = 0.0001
args.upsampling_type = 'interpolate' 
# Disabled alternative: 'pad' instead of 'interpolate'.
# args.upsampling_type = 'pad' # 'interpolate' 'pad
args.finetuning_mode = 'linear-probing'

In [None]:
# Load the pre-trained MOMENT weights, apply the freezing policy selected by
# args.finetuning_mode, and move the model to the configured device.
model = load_pretrained_moment(args)
model = freeze_model_parameters(args, model)
model.to(args.device)

In [None]:
# load_pretrained_moment() set args.task_name to the pre-training task; switch
# it to classification before building the dataloaders.
args.task_name = "classification"
train_dataloader, test_dataloader, val_dataloader = get_dataloaders(args)

In [None]:
# Run the model's reconstruction on the test split and collect, per batch, the
# input series, the reconstruction and the input mask.
trues = []
preds = []
input_masks = []

with torch.no_grad():
    for batch_x in tqdm(test_dataloader, total=len(test_dataloader)):
        timeseries = batch_x.timeseries.float().to(args.device)
        input_mask = batch_x.input_mask.long().to(args.device)

        outputs = model.reconstruct(
            x_enc=timeseries, input_mask=input_mask)
        
        preds.append(outputs.reconstruction.detach().cpu().numpy())
        trues.append(timeseries.detach().cpu().numpy())
        input_masks.append(input_mask.detach().cpu().numpy())

    # Stack the batches along axis 0 and drop singleton dimensions.
    # NOTE(review): squeeze() would also drop the sample axis if the split
    # contained a single series -- confirm that cannot happen.
    trues = np.concatenate(trues, axis=0).squeeze()
    preds = np.concatenate(preds, axis=0).squeeze()
    input_masks = np.concatenate(input_masks, axis=0).squeeze()

In [None]:
# Plot one randomly chosen test series with its reconstruction and input mask.
# No random seed is set in this cell, so each run may draw a different index;
# the index is shown in the title.
idx = np.random.randint(0, len(trues))
plt.title(f"idx: {idx}")
plt.plot(trues[idx], label="True")
plt.plot(preds[idx], label="Predicted")
plt.plot(input_masks[idx], label="Input mask")
plt.legend()
plt.show()

In [None]:
from tqdm import tqdm
from moment.utils.short_univariate_classification_datasets import \
    short_univariate_classification_datasets

# For every short univariate classification dataset, record its name, the
# number of train (train + val) and test series, the series length and the
# number of classes, as seen through the project's dataloaders.
args.task_name = "classification"

# get_dataloaders() reads the dataset path from `args` and writes
# args.dataset_names. Remember the values configured earlier so they can be put
# back once the loop is done.
configured_dataset_path = args.full_file_path_and_name
configured_dataset_names = getattr(args, "dataset_names", configured_dataset_path)

features = []
try:
    for dataset_name in tqdm(short_univariate_classification_datasets):
        args.full_file_path_and_name = dataset_name
        # Loop-local names on purpose: reusing train_dataloader / test_dataloader /
        # val_dataloader would leave the notebook-level loaders pointing at the
        # last dataset of the list, and the embedding / SVM cells further down
        # would then silently run on that dataset.
        ds_train_loader, ds_test_loader, ds_val_loader = get_dataloaders(args)

        # The arrays are concatenated along axis 1 and unpacked below as
        # (series length, number of series).
        train_data = np.concatenate([
            ds_train_loader.dataset.data, ds_val_loader.dataset.data], axis=1)
        labels = np.concatenate([ds_train_loader.dataset.labels, ds_val_loader.dataset.labels])
        num_classes = len(np.unique(labels.flatten()))

        len_timeseries, n_train = train_data.shape
        len_timeseries, n_test = ds_test_loader.dataset.data.shape

        # The dataset name is the parent directory of the data file.
        features.append([dataset_name.split("/")[-2], n_train, n_test, len_timeseries, num_classes])
finally:
    # Restore the dataset selection that was active before this cell.
    args.full_file_path_and_name = configured_dataset_path
    args.dataset_names = configured_dataset_names
    

In [None]:
# Turn the collected rows into a dataframe; 'problem' holds the dataset name.
features = pd.DataFrame(features, columns=['problem', 'num_train', 'num_test', 'series_length', 'num_classes'])

In [None]:
# Put the metadata from the summary CSV next to the values measured through the
# dataloaders, matched on the dataset name.
feature_comparison = summary.merge(features, on='problem')
feature_comparison.head()

In [None]:
# Per-dataset flags that are True where the summary CSV and the value measured
# through the dataloaders disagree.
class_bool = feature_comparison.numClasses.astype(int) != feature_comparison.num_classes.astype(int)
train_bool = feature_comparison.num_train.astype(int) != feature_comparison.numTrainCases.astype(int)
test_bool = feature_comparison.num_test.astype(int) != feature_comparison.numTestCases.astype(int)
length_bool = feature_comparison.series_length.astype(int) != feature_comparison.seriesLength.astype(int)

# Number of datasets with a mismatch, per property.
print(' Num. classes:', class_bool.sum())
print('  Train cases:', train_bool.sum())
print('   Test cases:', test_bool.sum())
print('Series length:', length_bool.sum())

In [None]:
# Datasets where the train count, test count or series length disagrees.
# NOTE(review): class_bool is not part of this filter -- confirm that
# class-count mismatches are meant to be left out.
mismatches = feature_comparison[train_bool | test_bool | length_bool]
mismatches

In [None]:
# Accuracies of the mismatching datasets: the index of `results` is matched
# against the 'problem' column, so `results` must be indexed by dataset name.
results.merge(mismatches, right_on='problem', left_index=True)

In [None]:
import time

import wandb
from torch import nn
from torch import optim
from tqdm import trange, tqdm

from moment.common import PATHS


def train(args, model, train_dataloader):
    """Fine-tune ``model`` on a reconstruction (MSE) objective.

    Parameters
    ----------
    args
        Configuration object; reads ``max_epoch``, ``init_lr``,
        ``weight_decay``, ``max_norm`` and ``device``.
    model : torch.nn.Module
        Must provide ``reconstruct(x_enc=..., input_mask=..., mask=...)``
        returning an object with a ``reconstruction`` attribute.
    train_dataloader
        Sized iterable of batches exposing ``timeseries`` and ``input_mask``.

    Returns
    -------
    The same ``model`` object, updated in place.

    Notes
    -----
    * Every batch loss and the current learning rate are logged to a
      Weights & Biases run, which is always closed, even if training raises.
    * A batch whose loss is NaN or infinite is logged and then skipped
      entirely: no backward pass, no gradient clipping, no optimizer step.
      Stepping anyway could move the weights using stale optimizer state, and
      back-propagating an infinite loss would corrupt the parameters.
    """
    n_train_epochs = args.max_epoch

    optimizer = optim.AdamW(model.parameters(),
                            lr=args.init_lr,
                            weight_decay=args.weight_decay)
    criterion = nn.MSELoss()

    logger = wandb.init(
        project="Time-series Foundation Model",
        dir=PATHS.WANDB_DIR)

    try:
        for epoch in trange(n_train_epochs):
            model.train()
            for batch in tqdm(train_dataloader, total=len(train_dataloader)):
                timeseries = batch.timeseries.float().to(args.device)
                input_mask = batch.input_mask.long().to(args.device)

                # Training step: reconstruct the unmasked input.
                outputs = model.reconstruct(x_enc=timeseries,
                                            input_mask=input_mask, mask=None)
                loss = criterion(outputs.reconstruction, timeseries)
                loss_value = loss.item()

                logger.log({
                    'step_loss_train': loss_value,
                    'lr': optimizer.param_groups[0]['lr']})

                if not np.isfinite(loss_value):
                    # Nothing usable to learn from this batch; make sure no
                    # gradients linger and move on.
                    optimizer.zero_grad()
                    continue

                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
                optimizer.step()
                optimizer.zero_grad()
    finally:
        logger.finish()

    return model

def get_embeddings_and_labels(model : torch.nn.Module, 
                              dataloader : torch.utils.data.DataLoader,
                              device : torch.device, 
                              enable_batchwise_pbar : bool = False):
    """Embed every batch of ``dataloader`` and return embeddings with labels.

    The model is moved to ``device`` and put in eval mode (it is not switched
    back afterwards). Each batch is embedded with
    ``model.embed(..., reduction='mean')`` under ``torch.no_grad()``.

    Parameters
    ----------
    model : torch.nn.Module
        Must provide ``embed`` returning an object with an ``embeddings`` tensor.
    dataloader : torch.utils.data.DataLoader
        Yields batches exposing ``timeseries``, ``input_mask`` and ``labels``.
    device : torch.device
        Device on which the forward passes run.
    enable_batchwise_pbar : bool
        Show a tqdm progress bar over the batches when True.

    Returns
    -------
    tuple of numpy.ndarray
        ``(embeddings, labels)``: per-batch results concatenated along axis 0,
        with singleton dimensions squeezed out of ``labels``.
    """
    model = model.to(device)
    model.eval()

    embedding_chunks, label_chunks = [], []
    batches = tqdm(dataloader, total=len(dataloader),
                   disable=(not enable_batchwise_pbar))

    with torch.no_grad():
        for batch in batches:
            series = batch.timeseries.float().to(device)
            mask = batch.input_mask.long().to(device)

            result = model.embed(x_enc=series, input_mask=mask, reduction='mean')

            embedding_chunks.append(result.embeddings.detach().cpu().numpy())
            label_chunks.append(batch.labels)

    stacked_embeddings = np.concatenate(embedding_chunks, axis=0)
    stacked_labels = np.concatenate(label_chunks, axis=0).squeeze()
    return stacked_embeddings, stacked_labels

In [None]:
# model = train(args, model, train_dataloader)

In [None]:
from moment.models.statistical_classifiers import fit_svm

# Embed the three splits with the current model.
train_embeddings, train_labels = get_embeddings_and_labels(
        model=model, dataloader=train_dataloader, 
        device=torch.device(args.device), 
        enable_batchwise_pbar=False)
    
test_embeddings, test_labels = get_embeddings_and_labels(
    model=model, dataloader=test_dataloader, 
    device=torch.device(args.device), 
    enable_batchwise_pbar=False)

val_embeddings, val_labels = get_embeddings_and_labels(
    model=model, dataloader=val_dataloader, 
    device=torch.device(args.device), 
    enable_batchwise_pbar=False)

# The validation split is folded into the training set before fitting.
train_embeddings = np.concatenate([train_embeddings, val_embeddings], axis=0)
train_labels = np.concatenate([train_labels, val_labels], axis=0)

# Fit the classifier on the combined train + validation embeddings.
classifier = fit_svm(features=train_embeddings, y=train_labels)

In [None]:
# Evaluate the model
# The score gets its own name: `test_accuracy` already holds the per-dataset
# accuracy dictionary loaded near the top of the notebook, and rebinding it to
# a single float would break the cells that build tables from that dictionary
# when they are re-run.
svm_test_accuracy = classifier.score(test_embeddings, test_labels)
print(f"Test accuracy: {svm_test_accuracy}")

In [None]:
# Row of the comparison table for the Beef dataset, for reference. The lookup
# goes through the index, so `results` must be indexed by dataset name.
results[results.index == 'Beef']

In [None]:
# Hand-picked groups of dataset names.
# NOTE(review): the variable names suggest image-derived and spectrograph
# datasets respectively -- confirm the grouping criterion.
SMALL_IMAGE_DATASETS = ['Crop', 'MedicalImages', 'SwedishLeaf', 
                        'FacesUCR', 'FaceAll', 'Adiac', 'ArrowHead']
SMALL_SPECTRO_DATASETS = ['Wine', 'Strawberry', 'Coffee', 'Ham', 'Meat', 'Beef']

# NOTE(review): this list is not bound to a name; it is only echoed as the cell
# output. Presumably a third group (phalanx-outline datasets) -- assign it to a
# variable if it is needed later.
['ProximalPhalanxTW', 'ProximalPhalanxOutlineCorrect', 'ProximalPhalanxOutlineAgeGroup',
 'PhalangesOutlinesCorrect', 'MiddlePhalanxTW', 'MiddlePhalanxOutlineCorrect', 'MiddlePhalanxOutlineAgeGroup',
 'DistalPhalanxTW', 'DistalPhalanxOutlineCorrect', 'DistalPhalanxOutlineAgeGroup']