In [1]:
import pandas as pd
import pandas as pd
import torch
import pickle
import os

import nlpsig
from nlpsig.deepsignet import StackedDeepSigNet
from nlpsig.focal_loss import FocalLoss, ClassBalanced_FocalLoss

In [2]:
def concatenate_data(data_folder_path):
    print(f"looking in {data_folder_path} directory...")
    manifesto_dfs = []
    for filename in os.listdir(data_folder_path):
        print(f"- reading {filename}...")
        # parse filename for metadata
        filename_split = filename.split("_")
        party_id = int(filename_split[0])
        year = int(filename_split[1][0:4])
        month = int(filename_split[1][4:6])
        doc_id = f"{party_id}_{year}"
        # read dataframe and add metadata
        df = pd.read_csv(f"data/{filename}")[["text", "cmp_code"]]
        df = df[df["cmp_code"]!="H"].dropna().reset_index(drop=True)
        df["topic"] = [int(str(code)[0]) for code in df["cmp_code"]]
        df["switched_topic"] = [1] + [int(df["topic"].iloc[i]!=df["topic"].iloc[i-1])
                                      for i in range(1, len(df))]
        df["party_id"] = party_id
        df["doc_id"] = f"{party_id}_{year}"
        df["datetime"] = pd.Timestamp(f"{year}-{month}")
        manifesto_dfs.append(df)
    return pd.concat(manifesto_dfs).reset_index(drop=True)

In [3]:
manifesto_df = concatenate_data("data/")

looking in data/ directory...
- reading 51902_201706.csv...
- reading 51902_201505.csv...
- reading 51320_201912.csv...
- reading 51620_201706.csv...
- reading 51620_201505.csv...
- reading 51421_201706.csv...
- reading 51421_201505.csv...
- reading 51421_201912.csv...
- reading 51902_201912.csv...
- reading 51320_201706.csv...
- reading 51620_201912.csv...
- reading 51320_201505.csv...


In [4]:
manifesto_df.head(10)

Unnamed: 0,text,cmp_code,topic,switched_topic,party_id,doc_id,datetime
0,SNP MPs have used their influence to deliver p...,305.1,3,1,51902,51902_2017,2017-06-01
1,Here’s just some of what a strong team of SNP ...,305.1,3,0,51902,51902_2017,2017-06-01
2,When the Scotland Bill was going through Westm...,301.0,3,0,51902,51902_2017,2017-06-01
3,"And it was SNP MPs, working with the Scottish ...",301.0,3,0,51902,51902_2017,2017-06-01
4,The SNP secured a deal that ensures Scotland w...,301.0,3,0,51902,51902_2017,2017-06-01
5,SNP MPs have consistently opposed Tory austerity.,504.0,5,1,51902,51902_2017,2017-06-01
6,Our MPs have been instrumental in forcing UK g...,504.0,5,0,51902,51902_2017,2017-06-01
7,Alison Thewliss has been at the forefront of t...,504.0,5,0,51902,51902_2017,2017-06-01
8,and force women to prove they have been raped ...,503.0,5,0,51902,51902_2017,2017-06-01
9,SNP MPs have worked with Women Against State P...,503.0,5,0,51902,51902_2017,2017-06-01


In [5]:
manifesto_df["switched_topic"].value_counts()

0    10746
1     4864
Name: switched_topic, dtype: int64

In [6]:
manifesto_df["topic"].value_counts()

5    5300
4    3872
6    1547
1    1504
2    1235
3    1101
7    1024
0      27
Name: topic, dtype: int64

In [7]:
manifesto_df["party_id"].value_counts()

51421    4515
51620    4307
51320    4039
51902    2749
Name: party_id, dtype: int64

In [8]:
manifesto_df["doc_id"].value_counts()

51421_2015    1917
51320_2019    1702
51620_2015    1588
51620_2017    1496
51421_2019    1467
51320_2017    1328
51620_2019    1223
51421_2017    1131
51902_2019    1071
51320_2015    1009
51902_2015     892
51902_2017     786
Name: doc_id, dtype: int64

## Model specifics

Nested dictionary for models specifications.

This includes models for encoding text, path signature and etc.

In [84]:
model_specifics = {
    "encoder_args": {
        "feature_name": "text", # column corresponding to the sentences
        "model_name": "all-mpnet-base-v2", #options: all-mpnet-base-v2, all-distilroberta-v1, all-MiniLM-L12-v2
        "model_encoder_args": {
            "batch_size": 64,
            "show_progress_bar": True,
            "output_value": 'sentence_embedding', 
            "convert_to_numpy": True,
            "convert_to_tensor": False,
            "device": None,
            "normalize_embeddings": False
        }
    },
    "dim_reduction": {
        "method": 'umap', #options: ppapca, ppapcappa, umap
        "n_components": 50, # options: any int number between 1 and embedding dimensions
    },
    "embedding":{
        "global_embedding_tp": 'SBERT', #options: SBERT, BERT_cls , BERT_mean, BERT_max
        "post_embedding_tp": 'sentence', #options: sentence, reduced
        "feature_combination_method": 'attention', #options concatenation, attention 
    },
    "time_injection": {
        "history_tp": 'timestamp', #options: timestamp, None
        "post_tp": 'timestamp', #options: timestamp, timediff, None
    },
    "signature": {
        "dimensions": 3, #options: any int number larger than 1
        "method": 'log', # options: log, sig
        "interval": 1/12
    }
}

## Obtaining SBERT Embeddings

We can use the `SentenceEncoder` class within `nlpsig` to obtain sentence embeddings from a model. Here, we have defined the encoder arguments in `model_specifics`.

In [85]:
model_specifics["encoder_args"]

{'feature_name': 'text',
 'model_name': 'all-mpnet-base-v2',
 'model_encoder_args': {'batch_size': 64,
  'show_progress_bar': True,
  'output_value': 'sentence_embedding',
  'convert_to_numpy': True,
  'convert_to_tensor': False,
  'device': None,
  'normalize_embeddings': False}}

We can pass these into the constructor of the class to initialise our text encoder as follows:

In [86]:
model_specifics["encoder_args"]["model_name"]

'all-mpnet-base-v2'

In [87]:
# initialise the Text Encoder 
text_encoder = nlpsig.SentenceEncoder(df=manifesto_df,
                                      **model_specifics["encoder_args"])
# load pretrained model (model_specifics["encoder_args"]["model_name"])
text_encoder.load_pretrained_model()

The class has a `.encode_sentence_transformer()` method which first loads in the model (using the `model_name` and `model_args` attributes) and then obtains an embedding for each sentence. These sentence embeddings are then stored in the `embeddings_sentence` attribute of the object.

In [88]:
text_encoder.obtain_embeddings()
embeddings_sentence = text_encoder.sentence_embeddings

[INFO] number of sentences to encode: 15610


Batches:   0%|          | 0/244 [00:00<?, ?it/s]

## Dimensionality Reduction with UMAP

Here we specified our choices in `model_specifics` above:

In [89]:
model_specifics["dim_reduction"]

{'method': 'umap', 'n_components': 50}

In [90]:
reduction = nlpsig.DimReduce(**model_specifics["dim_reduction"])
embeddings_reduced = reduction.fit_transform(embeddings_sentence)

In [91]:
print(embeddings_sentence.shape)
print(embeddings_reduced.shape)

(15610, 768)
(15610, 50)


## Data preparation: Time injection and Padding

In [92]:
manifesto_data = nlpsig.PrepareData(manifesto_df,
                                    id_column="doc_id",
                                    labels_column="switched_topic",
                                    embeddings=embeddings_sentence,
                                    embeddings_reduced=embeddings_reduced)

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...


In [93]:
manifesto_data.df["doc_id"].value_counts()

51421_2015    1917
51320_2019    1702
51620_2015    1588
51620_2017    1496
51421_2019    1467
51320_2017    1328
51620_2019    1223
51421_2017    1131
51902_2019    1071
51320_2015    1009
51902_2015     892
51902_2017     786
Name: doc_id, dtype: int64

In [94]:
manifesto_data.df[manifesto_data.df["doc_id"]=="51902_2017"]

Unnamed: 0,text,cmp_code,topic,switched_topic,party_id,doc_id,datetime,d1,d2,d3,...,e762,e763,e764,e765,e766,e767,e768,time_encoding,time_diff,timeline_index
13753,SNP MPs have used their influence to deliver p...,305.1,3,1,51902,51902_2017,2017-06-01,6.218624,5.405863,7.141194,...,0.015981,-0.014767,0.026448,0.002084,-0.017059,0.001952,0.015159,2017.413699,0.0,0
13754,Here’s just some of what a strong team of SNP ...,305.1,3,0,51902,51902_2017,2017-06-01,6.121441,5.327885,7.136108,...,-0.012643,-0.033004,0.015802,-0.003156,-0.017330,0.027937,0.013548,2017.413699,0.0,1
13755,When the Scotland Bill was going through Westm...,301,3,0,51902,51902_2017,2017-06-01,6.353848,5.650436,7.309916,...,-0.022350,-0.015041,-0.038259,-0.000943,-0.023851,-0.018221,0.020819,2017.413699,0.0,2
13756,"And it was SNP MPs, working with the Scottish ...",301,3,0,51902,51902_2017,2017-06-01,6.474751,5.557942,7.597468,...,0.029210,0.002856,-0.022373,0.014007,-0.006234,0.015914,0.018991,2017.413699,0.0,3
13757,The SNP secured a deal that ensures Scotland w...,301,3,0,51902,51902_2017,2017-06-01,6.893159,5.808902,7.574815,...,0.031720,-0.014070,0.001398,0.028775,0.021897,0.018755,-0.021968,2017.413699,0.0,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14534,The disgraceful condition of the housing provi...,201.2,2,0,51902,51902_2017,2017-06-01,5.152985,4.770663,6.742881,...,-0.011745,0.008737,-0.005464,0.014512,-0.022729,-0.041512,0.049706,2017.413699,0.0,781
14535,The Scottish Government’s work to resettle Syr...,201.2,2,0,51902,51902_2017,2017-06-01,4.823460,5.087404,6.052791,...,0.033104,-0.021670,-0.040489,0.022394,-0.003573,0.060122,0.058902,2017.413699,0.0,782
14536,We will urge the UK government to work with th...,301,3,1,51902,51902_2017,2017-06-01,4.777165,4.842837,6.092124,...,0.029256,-0.005607,-0.014617,0.028316,-0.026071,-0.015531,0.037127,2017.413699,0.0,783
14537,rather than use private contractors who have p...,413,4,1,51902,51902_2017,2017-06-01,5.953185,4.654902,6.219607,...,0.021102,0.066897,-0.031050,-0.022594,-0.004096,-0.021181,0.018906,2017.413699,0.0,784


## Obtaining path by looking at post history

In [95]:
time_features = ["timeline_index", "time_encoding"]
history_path = manifesto_data.pad(pad_by="history",
                                  method="k_last",
                                  zero_padding=True,
                                  k=10,
                                  time_feature=time_features,
                                  standardise_time_feature=False,
                                  embeddings="dim_reduced")

[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 15610/15610 [00:23<00:00, 653.79it/s]


In [96]:
history_path.shape

(15610, 10, 54)

In [99]:
manifesto_data.label_column

'switched_topic'

In [100]:
history_path[3]

array([[0, 2015.3287671232877, 5.585601329803467, 5.373188495635986,
        7.216732025146484, 6.7016377449035645, 4.010620594024658,
        1.9871352910995483, 4.562832355499268, 7.122035026550293,
        4.239704608917236, 3.925722599029541, 5.707393646240234,
        6.722735404968262, 3.756204128265381, 3.753582715988159,
        0.9717355370521545, 6.577835559844971, 5.097364902496338,
        5.210720539093018, 1.648322343826294, 6.5491838455200195,
        2.582465171813965, 3.3948566913604736, 3.4996731281280518,
        5.060497283935547, 4.592323303222656, 4.1421027183532715,
        5.254263401031494, 5.300419807434082, 4.375033855438232,
        5.108604431152344, 5.087096691131592, 5.567975997924805,
        8.174856185913086, 3.72859787940979, 3.9155287742614746,
        4.524398326873779, 5.853325843811035, 3.465318441390991,
        5.750558376312256, 7.4231767654418945, 4.017478942871094,
        4.7644829750061035, 4.94251012802124, 5.548069953918457,
        3.693

In [101]:
x_data, input_channels = manifesto_data.get_torch_path_for_SDSN(
    include_time_features_in_path = True,
    include_time_features_in_input = True,
    include_embedding_in_input = True,
    reduced_embeddings=False
)

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


In [102]:
x_data.shape

torch.Size([15610, 10, 822])

In [104]:
input_channels+len(time_features)+768

822

## StackedDeepSigNet

In [108]:
seed = 2022
SDSN_args = {
    "input_channels": input_channels,
    "output_channels": 10,
    "num_time_features": len(time_features),
    "embedding_dim": x_data.shape[2]-input_channels-len(time_features),
    "sig_depth": 3,
    "hidden_dim_lstm": (12, 8),
    "hidden_dim": 32,
    "output_dim": len(manifesto_data.df["switched_topic"].unique()),
    "dropout_rate": 0.25,
    "augmentation_type": "Conv1d",
    "augmentation_layers": (),
    "blocks": 2,
    "BiLSTM": False,
    "comb_method": "concatenation"
}

In [109]:
data_folds = nlpsig.Folds(x_data=x_data,
                          y_data=torch.tensor(manifesto_data.df["switched_topic"]),
                          n_splits=2,
                          shuffle=True,
                          random_state=seed)

In [121]:
train, valid, test = data_folds.get_splits(fold_index = 1,
                                           as_DataLoader = True,
                                           data_loader_args = {"batch_size": 512,
                                                               "shuffle": True})

In [122]:
# initial model definitions
model = StackedDeepSigNet(**SDSN_args)

# define loss
loss = "cross_entropy"
num_folds = 5
gamma = 2
beta = 0.999

if loss == "focal":    
    criterion = FocalLoss(gamma = gamma)
    y_train = data_folds.get_splits(fold_index=0)[5]
    criterion.set_alpha_from_y(y=y_train)
elif loss == "cbfocal":
    criterion = ClassBalanced_FocalLoss(gamma = gamma,
                                        beta = beta,
                                        no_of_classes = 2)
    y_train = data_folds.get_splits(fold_index=0)[5]
    criterion.set_samples_per_cls_from_y(y=y_train)
elif loss == "cross_entropy":
    criterion = torch.nn.CrossEntropyLoss()
    
# define optimizer
learning_rate = 0.00005
weight_decay_adam = 0.0001
optimizer = torch.optim.Adam(model.parameters(),
                             lr=learning_rate)

In [123]:
model.eval()
labels_all = torch.empty((0))
predicted_all = torch.empty((0))
with torch.no_grad():
    # Iterate through test dataset
    for emb_t, labels_t in test:
        # make prediction
        outputs_t = model(emb_t)
        
        _, predicted_t = torch.max(outputs_t.data, 1)
        # save predictions and labels
        labels_all = torch.cat([labels_all, labels_t])
        predicted_all = torch.cat([predicted_all, predicted_t])

In [124]:
torch.max(outputs_t, 1)

torch.return_types.max(
values=tensor([-1.2756e+00,  2.4272e+00, -3.2802e+00,  2.1688e+00,  2.4247e+00,
        -3.4462e+00,  1.4838e+00, -8.6385e-01,  2.4163e+00, -3.3502e+00,
         2.3720e+00,  2.4298e+00,  1.7554e+00,  2.8861e-01,  2.4336e+00,
         2.3991e+00, -2.0071e+00, -2.6585e+00, -2.4787e+00, -2.1174e-01,
         1.9039e+00,  2.4216e+00, -3.9033e+00,  2.3696e+00,  2.3864e+00,
         2.1087e+00,  2.4036e+00,  8.0496e-01,  2.3953e+00,  6.7340e-01,
        -8.0522e-01,  2.4265e+00,  2.4281e+00,  1.6523e-01,  2.4231e+00,
         2.4048e+00,  1.9673e+00,  2.3854e+00,  2.3859e+00,  2.4259e+00,
         2.1200e+00, -5.5159e-01, -3.8687e-01,  2.3187e-01, -3.5015e+00,
         2.7930e-01,  2.3806e+00,  6.7263e-01,  2.3914e+00, -6.8284e-01,
         2.1221e+00,  5.5558e-01, -4.6560e-01, -2.4544e+00,  2.3911e+00,
         2.4070e+00, -2.2983e+00,  2.0891e+00,  2.4128e+00,  1.4829e+00,
         2.3805e+00, -4.2980e-01, -3.3113e+00,  2.4217e+00, -1.3584e+00,
         2.4201e+00,

In [125]:
trained_model = nlpsig.training_pytorch(model = model,
                                        train_loader = train,
                                        valid_loader = valid,
                                        criterion = criterion,
                                        optimizer = optimizer,
                                        num_epochs = 1000,
                                        seed = seed,
                                        patience = 10,
                                        verbose = True,
                                        verbose_epoch = 20)

  0%|                                                                                                              | 0/1000 [00:00<?, ?it/s]

Epoch: 1/1000 || Item: 0/11 || Loss: 2.8005354404449463


  0%|                                                                                                      | 1/1000 [00:01<17:30,  1.05s/it]

--------------------------------------------------
##### Epoch: 1/1000 || Loss: 2.5524680614471436
--------------------------------------------------
Epoch: 1 || Loss: 2.4225887854894004 || Accuracy: 0.6948757767677307 || F1-score: 0.40998625744388456


  2%|██                                                                                                   | 20/1000 [00:20<16:25,  1.01s/it]

Epoch: 21/1000 || Item: 0/11 || Loss: 0.6099115014076233


  2%|██                                                                                                   | 21/1000 [00:21<16:22,  1.00s/it]

--------------------------------------------------
##### Epoch: 21/1000 || Loss: 0.6370500326156616
--------------------------------------------------
Epoch: 21 || Loss: 0.6043670276800791 || Accuracy: 0.6948757767677307 || F1-score: 0.40998625744388456


  4%|████                                                                                                 | 40/1000 [00:40<16:08,  1.01s/it]

Epoch: 41/1000 || Item: 0/11 || Loss: 0.6057258248329163


  4%|████▏                                                                                                | 41/1000 [00:41<16:04,  1.01s/it]

--------------------------------------------------
##### Epoch: 41/1000 || Loss: 0.639003574848175
--------------------------------------------------
Epoch: 41 || Loss: 0.6047622760136923 || Accuracy: 0.6948757767677307 || F1-score: 0.40998625744388456


  6%|██████                                                                                               | 60/1000 [01:00<15:46,  1.01s/it]

Epoch: 61/1000 || Item: 0/11 || Loss: 0.5969003438949585


  6%|██████▏                                                                                              | 61/1000 [01:01<15:43,  1.01s/it]

--------------------------------------------------
##### Epoch: 61/1000 || Loss: 0.5788325071334839
--------------------------------------------------
Epoch: 61 || Loss: 0.5883557001749674 || Accuracy: 0.6948757767677307 || F1-score: 0.40998625744388456


  8%|████████                                                                                             | 80/1000 [01:21<15:49,  1.03s/it]

Epoch: 81/1000 || Item: 0/11 || Loss: 0.6002448201179504


  8%|████████▏                                                                                            | 81/1000 [01:22<15:49,  1.03s/it]

--------------------------------------------------
##### Epoch: 81/1000 || Loss: 0.567249059677124
--------------------------------------------------
Epoch: 81 || Loss: 0.5905295809110006 || Accuracy: 0.6948757767677307 || F1-score: 0.40998625744388456


 10%|██████████                                                                                          | 100/1000 [01:42<15:19,  1.02s/it]

Epoch: 101/1000 || Item: 0/11 || Loss: 0.6132428646087646


 10%|██████████                                                                                          | 101/1000 [01:43<15:22,  1.03s/it]

--------------------------------------------------
##### Epoch: 101/1000 || Loss: 0.5665322542190552
--------------------------------------------------
Epoch: 101 || Loss: 0.5983361899852753 || Accuracy: 0.6948757767677307 || F1-score: 0.40998625744388456


 12%|████████████                                                                                        | 120/1000 [02:02<14:53,  1.02s/it]

Epoch: 121/1000 || Item: 0/11 || Loss: 0.6197059750556946


 12%|████████████                                                                                        | 121/1000 [02:03<14:56,  1.02s/it]

--------------------------------------------------
##### Epoch: 121/1000 || Loss: 0.6034596562385559
--------------------------------------------------
Epoch: 121 || Loss: 0.5835073590278625 || Accuracy: 0.6948757767677307 || F1-score: 0.40998625744388456


 14%|██████████████                                                                                      | 140/1000 [02:23<14:44,  1.03s/it]

Epoch: 141/1000 || Item: 0/11 || Loss: 0.612939715385437


 14%|██████████████                                                                                      | 141/1000 [02:24<14:39,  1.02s/it]

--------------------------------------------------
##### Epoch: 141/1000 || Loss: 0.5837703943252563
--------------------------------------------------
Epoch: 141 || Loss: 0.599367082118988 || Accuracy: 0.6948757767677307 || F1-score: 0.40998625744388456


 16%|████████████████                                                                                    | 160/1000 [02:43<14:22,  1.03s/it]

Epoch: 161/1000 || Item: 0/11 || Loss: 0.6173762679100037


 16%|████████████████                                                                                    | 161/1000 [02:44<14:17,  1.02s/it]

--------------------------------------------------
##### Epoch: 161/1000 || Loss: 0.6611766815185547
--------------------------------------------------
Epoch: 161 || Loss: 0.6174136400222778 || Accuracy: 0.6948757767677307 || F1-score: 0.40998625744388456


 18%|██████████████████                                                                                  | 180/1000 [03:04<14:10,  1.04s/it]

Epoch: 181/1000 || Item: 0/11 || Loss: 0.6334825754165649


 18%|██████████████████                                                                                  | 181/1000 [03:05<14:41,  1.08s/it]

--------------------------------------------------
##### Epoch: 181/1000 || Loss: 0.6217061281204224
--------------------------------------------------
Epoch: 181 || Loss: 0.5959498186906179 || Accuracy: 0.6948757767677307 || F1-score: 0.40998625744388456


 19%|███████████████████▎                                                                                | 193/1000 [03:19<13:52,  1.03s/it]


KeyboardInterrupt: 

Baselines:
   - just looking at the sentence embeddings (encodes nothing about the history on the post)
       - highlights importance of looking at the sequence
   - averaging history
   - comparing the cosine similarity between previous post and current post to see if switch
   
Test for:
- How many posts do you need to look back?

In [180]:
manifesto_data.df["doc_id"].value_counts()

51421_2015    1917
51320_2019    1702
51620_2015    1588
51620_2017    1496
51421_2019    1467
51320_2017    1328
51620_2019    1223
51421_2017    1131
51902_2019    1071
51320_2015    1009
51902_2015     892
51902_2017     786
Name: doc_id, dtype: int64

In [179]:
manifesto_data.df_padded[manifesto_data.df_padded["doc_id"]=="51320_2015"]

Unnamed: 0,timeline_index,d1,d2,d3,d4,d5,d6,d7,d8,d9,d10,doc_id,switched_topic
0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,51320_2015,-1
1,1,5.452659,3.602911,8.050925,10.075459,3.911385,9.213820,5.951643,4.275470,1.188564,2.751409,51320_2015,1
2,2,5.111912,3.501168,7.925458,10.427611,3.838215,9.157601,6.385989,4.594971,1.200033,3.107747,51320_2015,0
3,3,5.043742,3.522960,7.468404,9.604556,3.859646,9.537117,6.133909,4.508759,0.930652,2.949600,51320_2015,0
4,4,5.246913,3.351569,7.850750,8.427454,4.078413,9.972346,6.232757,4.770494,1.519472,2.442487,51320_2015,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1004,1004,4.939366,2.744935,6.556887,9.778514,5.032885,9.680021,8.184111,4.553108,1.828841,2.317884,51320_2015,1
1005,1005,3.629394,2.353125,7.404681,10.038898,5.321763,9.898950,7.223835,4.316514,2.549977,2.290533,51320_2015,1
1006,1006,4.761509,3.163177,7.822935,10.507176,4.688953,8.588090,8.575832,4.712253,1.761940,2.563453,51320_2015,1
1007,1007,4.707180,3.207190,7.543126,10.101868,4.077579,9.360678,6.383281,4.666095,1.278698,3.281799,51320_2015,0


In [183]:
manifesto_data.get_torch_path(include_time_features=False).shape

torch.Size([15610, 1, 10])

In [72]:
history_path[3]

array([[0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        '51320_2015', -1],
       [0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        '51320_2015', -1],
       [1, 5.452658653259277, 3.6029112339019775, 8.050925254821777,
        10.075458526611328, 3.9113845825195312, 9.213820457458496,
        5.951642990112305, 4.27547025680542, 1.1885640621185303,
        2.751408576965332, '51320_2015', 1],
       [2, 5.111911773681641, 3.5011675357818604, 7.9254584312438965,
        10.427611351013184, 3.838214874267578, 9.157601356506348,
        6.385989189147949, 4.594970703125, 1.2000325918197632,
        3.1077473163604736, '51320_2015', 0],
       [3, 5.043741703033447, 3.5229599475860596, 7.4684038162231445,
        9.6045560836792, 3.8596458435058594, 9.537117004394531,
        6.133909225463867, 4.508759498596191, 0.9306520223617554,
        2.9496004581451416, '51320_2015', 0]], dtype=object)

In [None]:
augmentation_tp = model_specifics["augmentation_tp"]
input_channels = path.shape[2]
output_channels =  [model_specifics["reduced_network_components"]] #13#[10,13]
augmentation_layers = () #[(32, 16, 10)] #(50, 20, output_channels) #
BiLSTM = False
sig_d = 3 
blocks = 3
post_dim = x_data.shape[1]- input_channels
hidden_dim_lstm =  [(12, 8)] #12 [10,12]
hidden_dim = [32] #32 [32,64] 
output_dim = 3
loss = model_specifics["loss_function"] #'focal' #cbfocal
dropout_rate = [0.25]  #0.25 [0.25,0.35]
if (model_specifics['time_injection_history_tp'] == 'timestamp'):
    add_time = True
else: 
    add_time = False

In [None]:
lass DeepSigNet(nn.Module):
    def __init__(
        self,
        input_channels,
        output_channels,
        sig_d,
        post_dim,
        hidden_dim,
        output_dim,
        dropout_rate,
        add_time=False,
        augmentation_tp="Conv1d",
        augmentation_layers=(),
    )