In [90]:
import pandas as pd
import pandas as pd
import torch
import pickle
import os

import nlpsig
from nlpsig.deepsignet import StackedDeepSigNet
from nlpsig.focal_loss import FocalLoss, ClassBalanced_FocalLoss

## Manifesto Project Database

We first create a function to read in data which should be in a `data/` folder in the current directory. The folders have the form `partycode_electiondata.csv` where `partycode` denotes the political party which wrote the manifesto and `electiondate` is the year and month of the election. All of the data was downloaded from the [Manifesto Project Database](https://manifesto-project.wzb.eu/).

Here, we focus on elections in the UK over the last 3 elections (in 05/2015, 06/2017 and 12/2019), and have annotated manifestos from the following parties:
- The Conservative Party (51620)
- The Labour Party (51320)
- Liberal Democrats (51421)
- Scottish National Party (51902)
- Green Party of England and Wales (51110)
- The Party of Wales (51901)
- Democratic Unionist Party (51903)
- We Ourselves (51210)
- Democratic Unionist Party (51903)
- United Kingdom Independence Party (51951)

The coding schemes of how the project annotated each sentence can be found [here](https://manifesto-project.wzb.eu/coding_schemes/mp_v5).

In [91]:
party_dict = {51620: "The Conservative Party",
              51320: "The Labour Party",
              51421: "Liberal Democrats", 
              51902: "Scottish National Party",
              51110: "Green Party of England and Wales",
              51901: "The Party of Wales",
              51903: "Democratic Unionist Party", 
              51210: "We Ourselves",
              51903: "Democratic Unionist Party",
              51951: "United Kingdom Independence Party",
              51340: "Social Democratic and Labour Party",
              51621: "Ulster Unionist Party",
              51430: "Alliance Party of Northern Ireland"}

In [92]:
def concatenate_data(data_folder_path):
    print(f"looking in {data_folder_path} directory...")
    manifesto_dfs = []
    for filename in os.listdir(data_folder_path):
        # parse filename for metadata
        filename_split = filename.split("_")
        party_id = int(filename_split[0])
        year = int(filename_split[1][0:4])
        month = int(filename_split[1][4:6])
        doc_id = f"{party_id}_{year}"
        print(f"- reading {filename} ({party_dict.get(party_id)} in {month}/{year})")
        # read dataframe and add metadata
        df = pd.read_csv(f"data/{filename}")[["text", "cmp_code"]]
        df = df[df["cmp_code"]!="H"].dropna().reset_index(drop=True)
        df["topic"] = [int(str(code)[0]) for code in df["cmp_code"]]
        df["switched_topic"] = [1] + [int(df["topic"].iloc[i]!=df["topic"].iloc[i-1])
                                      for i in range(1, len(df))]
        df["party_id"] = party_id
        df["doc_id"] = f"{party_id}_{year}"
        df["party_name"] = party_dict.get(party_id)
        df["datetime"] = pd.Timestamp(f"{year}-{month}")
        manifesto_dfs.append(df)
    return pd.concat(manifesto_dfs).reset_index(drop=True)

In [93]:
manifesto_df = concatenate_data("data/")

looking in data/ directory...
- reading 51902_201706.csv (Scottish National Party in 6/2017)
- reading 51902_201505.csv (Scottish National Party in 5/2015)
- reading 51901_201912.csv (The Party of Wales in 12/2019)
- reading 51320_201912.csv (The Labour Party in 12/2019)
- reading 51620_201706.csv (The Conservative Party in 6/2017)
- reading 51620_201505.csv (The Conservative Party in 5/2015)
- reading 51110_201505.csv (Green Party of England and Wales in 5/2015)
- reading 51110_201706.csv (Green Party of England and Wales in 6/2017)
- reading 51903_201912.csv (Democratic Unionist Party in 12/2019)
- reading 51210_201912.csv (We Ourselves in 12/2019)
- reading 51421_201706.csv (Liberal Democrats in 6/2017)
- reading 51340_201912.csv (Social Democratic and Labour Party in 12/2019)
- reading 51421_201505.csv (Liberal Democrats in 5/2015)
- reading 51903_201706.csv (Democratic Unionist Party in 6/2017)
- reading 51210_201706.csv (We Ourselves in 6/2017)
- reading 51210_201505.csv (We Ours

From inspecting the data, we have sentences in the `text` column of `manifesto_df` where we have the detailed sentence topics in `cmp_code`, the general summary topic in `topic` and a binary value denoting if a switch of topic has occured in `switched_topic`. We also have the party-id and the corresponding document-id too in `party_id` and `doc_id` respectively. We lastly have the date in which the document was released.

In [94]:
manifesto_df.head(10)

Unnamed: 0,text,cmp_code,topic,switched_topic,party_id,doc_id,party_name,datetime
0,SNP MPs have used their influence to deliver p...,305.1,3,1,51902,51902_2017,Scottish National Party,2017-06-01
1,Here’s just some of what a strong team of SNP ...,305.1,3,0,51902,51902_2017,Scottish National Party,2017-06-01
2,When the Scotland Bill was going through Westm...,301.0,3,0,51902,51902_2017,Scottish National Party,2017-06-01
3,"And it was SNP MPs, working with the Scottish ...",301.0,3,0,51902,51902_2017,Scottish National Party,2017-06-01
4,The SNP secured a deal that ensures Scotland w...,301.0,3,0,51902,51902_2017,Scottish National Party,2017-06-01
5,SNP MPs have consistently opposed Tory austerity.,504.0,5,1,51902,51902_2017,Scottish National Party,2017-06-01
6,Our MPs have been instrumental in forcing UK g...,504.0,5,0,51902,51902_2017,Scottish National Party,2017-06-01
7,Alison Thewliss has been at the forefront of t...,504.0,5,0,51902,51902_2017,Scottish National Party,2017-06-01
8,and force women to prove they have been raped ...,503.0,5,0,51902,51902_2017,Scottish National Party,2017-06-01
9,SNP MPs have worked with Women Against State P...,503.0,5,0,51902,51902_2017,Scottish National Party,2017-06-01


In [96]:
manifesto_df["switched_topic"].value_counts()

0    18914
1     8438
Name: switched_topic, dtype: int64

The topic codes (in `cmp_code`) correspond to an overall domain category and full details of these codes can be found [here](https://manifesto-project.wzb.eu/coding_schemes/mp_v5). The domain categories are as follows:
- Domain 1: External Relations
- Domain 2: Freedom and Democracy
- Domain 3: Political System
- Domain 4: Economy
- Domain 5: Welfare and Quality of Life
- Domain 6: Fabric of Society
- Domain 7: Social Groups

Note that sentences with topic `0` were given no meaningful category in the dataset.

In [97]:
manifesto_df["topic"].value_counts()

5    9482
4    6262
6    2991
1    2792
3    2060
2    2003
7    1700
0      62
Name: topic, dtype: int64

In [98]:
manifesto_df["party_name"].value_counts()

Liberal Democrats                     4515
The Conservative Party                4307
The Labour Party                      4039
Green Party of England and Wales      3640
Scottish National Party               2749
United Kingdom Independence Party     2428
The Party of Wales                    2063
Democratic Unionist Party             1135
Social Democratic and Labour Party     845
Alliance Party of Northern Ireland     622
We Ourselves                           592
Ulster Unionist Party                  417
Name: party_name, dtype: int64

In [99]:
manifesto_df["doc_id"].value_counts()

51110_2015    2233
51421_2015    1917
51320_2019    1702
51620_2015    1588
51620_2017    1496
51421_2019    1467
51951_2015    1348
51320_2017    1328
51620_2019    1223
51110_2019    1198
51421_2017    1131
51951_2017    1080
51902_2019    1071
51320_2015    1009
51901_2019     951
51902_2015     892
51902_2017     786
51901_2015     776
51430_2019     622
51903_2019     473
51340_2019     438
51903_2017     433
51621_2015     417
51340_2015     407
51901_2017     336
51210_2015     272
51903_2015     229
51110_2017     209
51210_2017     193
51210_2019     127
Name: doc_id, dtype: int64

## Obtaining SBERT Embeddings

We can use the `SentenceEncoder` class within `nlpsig` to obtain sentence embeddings from a model. This class uses the [`sentence-transformer`](https://www.sbert.net/docs/package_reference/SentenceTransformer.html) package and here, we have use the pre-trained `all-mpnet-base-v2` model by passing this name as a string to the class - alternative models can be found [here](https://www.sbert.net/docs/pretrained_models.html).

We can pass these into the constructor of the class to initialise our text encoder as follows:

In [101]:
# initialise the Text Encoder 
text_encoder = nlpsig.SentenceEncoder(df=manifesto_df,
                                      feature_name="text",
                                      model_name="all-mpnet-base-v2")
text_encoder.load_pretrained_model()

The class has a `.encode_sentence_transformer()` method which first loads in the model (using the `model_name` and `model_args` attributes) and then obtains an embedding for each sentence. These sentence embeddings are then stored in the `embeddings_sentence` attribute of the object.

In [102]:
text_encoder.obtain_embeddings()
embeddings_sentence = text_encoder.sentence_embeddings

[INFO] number of sentences to encode: 27352


Batches:   0%|          | 0/428 [00:00<?, ?it/s]

## Dimensionality Reduction with UMAP

In [103]:
reduction = nlpsig.DimReduce(method="umap",
                             n_components=50)
embeddings_reduced = reduction.fit_transform(embeddings_sentence)

In [104]:
print(embeddings_sentence.shape)
print(embeddings_reduced.shape)

(27352, 768)
(27352, 50)


## Data preparation: Time injection and Padding

In [105]:
manifesto_data = nlpsig.PrepareData(manifesto_df,
                                    id_column="doc_id",
                                    labels_column="switched_topic",
                                    embeddings=embeddings_sentence,
                                    embeddings_reduced=embeddings_reduced)

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...


In [106]:
manifesto_data.df[manifesto_data.df["doc_id"]=="51902_2017"]

Unnamed: 0,text,cmp_code,topic,switched_topic,party_id,doc_id,party_name,datetime,d1,d2,...,e762,e763,e764,e765,e766,e767,e768,time_encoding,time_diff,timeline_index
21932,SNP MPs have used their influence to deliver p...,305.1,3,1,51902,51902_2017,Scottish National Party,2017-06-01,4.780281,5.795990,...,0.015981,-0.014767,0.026448,0.002084,-0.017059,0.001952,0.015159,2017.413699,0.0,0
21933,Here’s just some of what a strong team of SNP ...,305.1,3,0,51902,51902_2017,Scottish National Party,2017-06-01,4.667715,5.782881,...,-0.012643,-0.033004,0.015802,-0.003156,-0.017330,0.027937,0.013548,2017.413699,0.0,1
21934,When the Scotland Bill was going through Westm...,301,3,0,51902,51902_2017,Scottish National Party,2017-06-01,5.019376,6.432547,...,-0.022350,-0.015041,-0.038259,-0.000943,-0.023851,-0.018221,0.020819,2017.413699,0.0,2
21935,"And it was SNP MPs, working with the Scottish ...",301,3,0,51902,51902_2017,Scottish National Party,2017-06-01,5.053892,5.912344,...,0.029210,0.002856,-0.022373,0.014007,-0.006234,0.015914,0.018991,2017.413699,0.0,3
21936,The SNP secured a deal that ensures Scotland w...,301,3,0,51902,51902_2017,Scottish National Party,2017-06-01,5.530610,6.084226,...,0.031720,-0.014070,0.001398,0.028775,0.021897,0.018755,-0.021968,2017.413699,0.0,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22713,The disgraceful condition of the housing provi...,201.2,2,0,51902,51902_2017,Scottish National Party,2017-06-01,4.305221,4.830837,...,-0.011745,0.008737,-0.005464,0.014512,-0.022729,-0.041512,0.049706,2017.413699,0.0,781
22714,The Scottish Government’s work to resettle Syr...,201.2,2,0,51902,51902_2017,Scottish National Party,2017-06-01,3.669418,5.439607,...,0.033104,-0.021670,-0.040489,0.022394,-0.003573,0.060122,0.058902,2017.413699,0.0,782
22715,We will urge the UK government to work with th...,301,3,1,51902,51902_2017,Scottish National Party,2017-06-01,3.753327,5.268013,...,0.029256,-0.005607,-0.014617,0.028316,-0.026071,-0.015531,0.037127,2017.413699,0.0,783
22716,rather than use private contractors who have p...,413,4,1,51902,51902_2017,Scottish National Party,2017-06-01,5.596834,4.672209,...,0.021102,0.066897,-0.031050,-0.022594,-0.004096,-0.021181,0.018906,2017.413699,0.0,784


## Obtaining path by looking at post history

We can obtain a path by looking at the history of each post. Here we look at the last 10 posts (and pad with vectors of zeros if there are less than 10 posts) including the current post.

In [121]:
time_features = []
history_path = manifesto_data.pad(pad_by="history",
                                  method="k_last",
                                  zero_padding=True,
                                  k=10,
                                  time_feature=time_features,
                                  embeddings="dim_reduced")

[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.




  0%|                                      | 0/27352 [00:00<?, ?it/s][A[A

  0%|                            | 44/27352 [00:00<01:02, 436.51it/s][A[A

  0%|                           | 106/27352 [00:00<00:50, 541.44it/s][A[A

  1%|▏                          | 166/27352 [00:00<00:47, 567.34it/s][A[A

  1%|▏                          | 224/27352 [00:00<00:47, 571.87it/s][A[A

  1%|▎                          | 282/27352 [00:00<00:47, 569.04it/s][A[A

  1%|▎                          | 339/27352 [00:00<00:47, 566.35it/s][A[A

  1%|▍                          | 396/27352 [00:00<00:48, 551.72it/s][A[A

  2%|▍                          | 452/27352 [00:00<00:49, 542.26it/s][A[A

  2%|▌                          | 507/27352 [00:00<00:50, 532.23it/s][A[A

  2%|▌                          | 561/27352 [00:01<00:52, 515.02it/s][A[A

  2%|▌                          | 613/27352 [00:01<00:53, 503.61it/s][A[A

  2%|▋                          | 664/27352 [00:01<00:54, 493.85it/s][A[

 19%|████▉                     | 5139/27352 [00:11<00:47, 470.34it/s][A[A

 19%|████▉                     | 5187/27352 [00:11<00:47, 468.75it/s][A[A

 19%|████▉                     | 5234/27352 [00:11<00:47, 461.65it/s][A[A

 19%|█████                     | 5290/27352 [00:11<00:45, 489.10it/s][A[A

 20%|█████                     | 5355/27352 [00:11<00:41, 535.03it/s][A[A

 20%|█████▏                    | 5417/27352 [00:11<00:39, 559.14it/s][A[A

 20%|█████▏                    | 5477/27352 [00:11<00:38, 570.93it/s][A[A

 20%|█████▎                    | 5538/27352 [00:11<00:37, 580.19it/s][A[A

 20%|█████▎                    | 5597/27352 [00:11<00:37, 580.14it/s][A[A

 21%|█████▍                    | 5656/27352 [00:11<00:37, 581.11it/s][A[A

 21%|█████▍                    | 5715/27352 [00:12<00:37, 571.72it/s][A[A

 21%|█████▍                    | 5773/27352 [00:12<00:38, 565.18it/s][A[A

 21%|█████▌                    | 5830/27352 [00:12<00:38, 555.44it/s][A[A


 40%|█████████▉               | 10809/27352 [00:22<00:43, 381.37it/s][A[A

 40%|█████████▉               | 10848/27352 [00:22<00:43, 380.60it/s][A[A

 40%|█████████▉               | 10887/27352 [00:22<00:43, 378.92it/s][A[A

 40%|█████████▉               | 10925/27352 [00:22<00:44, 373.06it/s][A[A

 40%|██████████               | 10963/27352 [00:22<00:44, 371.77it/s][A[A

 40%|██████████               | 11001/27352 [00:22<00:45, 360.23it/s][A[A

 40%|██████████               | 11038/27352 [00:22<00:46, 353.17it/s][A[A

 41%|██████████▏              | 11098/27352 [00:22<00:38, 422.95it/s][A[A

 41%|██████████▏              | 11158/27352 [00:23<00:34, 473.60it/s][A[A

 41%|██████████▎              | 11219/27352 [00:23<00:31, 512.76it/s][A[A

 41%|██████████▎              | 11280/27352 [00:23<00:29, 540.07it/s][A[A

 41%|██████████▎              | 11341/27352 [00:23<00:28, 559.71it/s][A[A

 42%|██████████▍              | 11402/27352 [00:23<00:27, 573.38it/s][A[A


 60%|███████████████          | 16499/27352 [00:33<00:19, 547.19it/s][A[A

 61%|███████████████▏         | 16554/27352 [00:33<00:20, 525.98it/s][A[A

 61%|███████████████▏         | 16607/27352 [00:33<00:20, 515.62it/s][A[A

 61%|███████████████▏         | 16659/27352 [00:33<00:20, 514.06it/s][A[A

 61%|███████████████▎         | 16711/27352 [00:33<00:20, 511.64it/s][A[A

 61%|███████████████▎         | 16763/27352 [00:34<00:20, 507.26it/s][A[A

 61%|███████████████▎         | 16814/27352 [00:34<00:20, 503.76it/s][A[A

 62%|███████████████▍         | 16865/27352 [00:34<00:21, 496.07it/s][A[A

 62%|███████████████▍         | 16915/27352 [00:34<00:21, 488.26it/s][A[A

 62%|███████████████▌         | 16964/27352 [00:34<00:21, 481.02it/s][A[A

 62%|███████████████▌         | 17013/27352 [00:34<00:22, 464.82it/s][A[A

 62%|███████████████▌         | 17060/27352 [00:34<00:22, 454.56it/s][A[A

 63%|███████████████▋         | 17106/27352 [00:34<00:23, 444.33it/s][A[A


 82%|████████████████████▌    | 22547/27352 [00:44<00:09, 533.79it/s][A[A

 83%|████████████████████▋    | 22601/27352 [00:44<00:09, 526.36it/s][A[A

 83%|████████████████████▋    | 22654/27352 [00:44<00:08, 522.18it/s][A[A

 83%|████████████████████▊    | 22707/27352 [00:44<00:09, 515.27it/s][A[A

 83%|████████████████████▊    | 22764/27352 [00:45<00:08, 530.65it/s][A[A

 83%|████████████████████▊    | 22829/27352 [00:45<00:08, 563.45it/s][A[A

 84%|████████████████████▉    | 22892/27352 [00:45<00:07, 581.66it/s][A[A

 84%|████████████████████▉    | 22954/27352 [00:45<00:07, 592.33it/s][A[A

 84%|█████████████████████    | 23016/27352 [00:45<00:07, 597.91it/s][A[A

 84%|█████████████████████    | 23076/27352 [00:45<00:07, 594.30it/s][A[A

 85%|█████████████████████▏   | 23136/27352 [00:45<00:07, 592.41it/s][A[A

 85%|█████████████████████▏   | 23196/27352 [00:45<00:07, 584.09it/s][A[A

 85%|█████████████████████▎   | 23255/27352 [00:45<00:07, 576.19it/s][A[A


In [122]:
history_path.shape

(27352, 10, 52)

In [123]:
x_data, input_channels = manifesto_data.get_torch_path_for_SDSN(
    include_time_features_in_path=True,
    include_time_features_in_input=True,
    include_embedding_in_input=False,
    reduced_embeddings=False
)

In [124]:
x_data.shape

torch.Size([27352, 10, 50])

In [125]:
input_channels

50

## StackedDeepSigNet

In [126]:
x_data.shape[2]-input_channels-len(time_features)

0

In [128]:
seed = 2022
SDSN_args = {
    "input_channels": input_channels,
    "output_channels": 10,
    "num_time_features": len(time_features),
    "embedding_dim": x_data.shape[2]-input_channels-len(time_features),
    "sig_depth": 3,
    "hidden_dim_lstm": (12, 8),
    "hidden_dim": 32,
    "output_dim": len(manifesto_data.df["switched_topic"].unique()),
    "dropout_rate": 0.25,
    "augmentation_type": "Conv1d",
    "augmentation_layers": (),
    "blocks": 2,
    "BiLSTM": False,
    "comb_method": "concatenation"
}

In [129]:
data_folds = nlpsig.Folds(x_data=x_data,
                          y_data=torch.tensor(manifesto_data.df["switched_topic"]),
                          n_splits=2,
                          shuffle=True,
                          random_state=seed)

In [130]:
train, valid, test = data_folds.get_splits(fold_index = 1,
                                           as_DataLoader = True,
                                           data_loader_args = {"batch_size": 512,
                                                               "shuffle": True})

In [131]:
batch = next(iter(train))

In [132]:
batch

[tensor([[[4.8575, 4.4757, 3.8674,  ..., 6.7541, 4.9256, 4.6695],
          [3.0103, 4.0449, 3.9233,  ..., 6.8126, 5.0036, 4.5317],
          [3.9453, 5.6912, 4.2700,  ..., 6.6684, 5.1065, 4.5260],
          ...,
          [3.4957, 3.6898, 3.7919,  ..., 6.8552, 4.9413, 4.5899],
          [3.3289, 3.7822, 4.4818,  ..., 6.8304, 4.9950, 4.5095],
          [2.8045, 3.4285, 4.5322,  ..., 6.8266, 5.0292, 4.5486]],
 
         [[5.2815, 5.1535, 3.5655,  ..., 6.8262, 4.9555, 4.6605],
          [6.2355, 4.5312, 3.3463,  ..., 6.6793, 5.0853, 4.5737],
          [5.6686, 4.8074, 3.8321,  ..., 6.8296, 5.0129, 4.5364],
          ...,
          [4.7169, 5.4783, 3.8919,  ..., 6.8896, 4.9896, 4.5181],
          [4.5542, 5.4786, 3.7742,  ..., 6.9192, 4.9676, 4.5535],
          [5.6405, 4.6693, 3.7480,  ..., 6.8188, 5.0235, 4.4820]],
 
         [[3.3951, 4.2210, 3.4105,  ..., 6.8747, 4.9638, 4.5165],
          [2.7397, 3.7160, 3.4474,  ..., 6.8272, 5.0390, 4.5734],
          [2.7316, 4.0132, 3.7262,  ...,

In [133]:
# initial model definitions
model = StackedDeepSigNet(**SDSN_args)

# define loss
loss = "cross_entropy"
num_folds = 5
gamma = 2
beta = 0.999

if loss == "focal":    
    criterion = FocalLoss(gamma = gamma)
    y_train = data_folds.get_splits(fold_index=0)[5]
    criterion.set_alpha_from_y(y=y_train)
elif loss == "cbfocal":
    criterion = ClassBalanced_FocalLoss(gamma = gamma,
                                        beta = beta,
                                        no_of_classes = 2)
    y_train = data_folds.get_splits(fold_index=0)[5]
    criterion.set_samples_per_cls_from_y(y=y_train)
elif loss == "cross_entropy":
    criterion = torch.nn.CrossEntropyLoss()
    
# define optimizer
learning_rate = 0.00005
weight_decay_adam = 0.0001
optimizer = torch.optim.Adam(model.parameters(),
                             lr=learning_rate)

In [134]:
model.eval()
labels_all = torch.empty((0))
predicted_all = torch.empty((0))
with torch.no_grad():
    # Iterate through test dataset
    for emb_t, labels_t in test:
        # make prediction
        outputs_t = model(emb_t)
        
        _, predicted_t = torch.max(outputs_t.data, 1)
        # save predictions and labels
        labels_all = torch.cat([labels_all, labels_t])
        predicted_all = torch.cat([predicted_all, predicted_t])

In [135]:
torch.max(outputs_t, 1)

torch.return_types.max(
values=tensor([0.1046, 0.1049, 0.1045, 0.1042, 0.1053, 0.1051, 0.1043, 0.1050, 0.1050,
        0.1050, 0.1049, 0.1045, 0.1025, 0.1041, 0.1034, 0.1042, 0.1058, 0.1053,
        0.1049, 0.1050, 0.1058, 0.1053, 0.1049, 0.1038, 0.1051, 0.1052, 0.1048,
        0.1044, 0.1050, 0.1049, 0.1051, 0.1056, 0.1054, 0.1059, 0.1050, 0.1046,
        0.1060, 0.1045, 0.1053, 0.1046, 0.1048, 0.1046, 0.1049, 0.1056, 0.1047,
        0.1052, 0.1041, 0.1041, 0.1047, 0.1052, 0.1052, 0.1051, 0.1042, 0.1052,
        0.1052, 0.1053, 0.1039, 0.1055, 0.1058, 0.1049, 0.1051, 0.1042, 0.1047,
        0.1051, 0.1048, 0.1043, 0.1043, 0.1031, 0.1046, 0.1054, 0.1054, 0.1051,
        0.1056, 0.1039, 0.1053, 0.1049, 0.0925, 0.1042, 0.1051, 0.1052, 0.1054,
        0.1047, 0.1055, 0.1046, 0.1048, 0.1058, 0.1046, 0.1045, 0.1057, 0.1046,
        0.1049, 0.1051, 0.1050, 0.1049, 0.1047, 0.1044, 0.1048, 0.1043, 0.1038,
        0.1046, 0.1049, 0.1034, 0.1054, 0.1052, 0.1052, 0.1056, 0.0932, 0.1036,
        0

In [None]:
trained_model = nlpsig.training_pytorch(model = model,
                                        train_loader = train,
                                        valid_loader = valid,
                                        criterion = criterion,
                                        optimizer = optimizer,
                                        num_epochs = 1000,
                                        seed = seed,
                                        patience = 10,
                                        verbose = True,
                                        verbose_epoch = 20)

  0%|                                       | 0/1000 [00:00<?, ?it/s]

Epoch: 1/1000 || Item: 0/18 || Loss: 0.7072409987449646
--------------------------------------------------
##### Epoch: 1/1000 || Loss: 0.7021772265434265
--------------------------------------------------


  0%|                               | 1/1000 [00:01<32:17,  1.94s/it]

Epoch: 1 || Loss: 0.7034725414382087 || Accuracy: 0.3054940104484558 || F1-score: 0.2340064483285254


  2%|▌                             | 20/1000 [00:37<30:18,  1.86s/it]

Epoch: 21/1000 || Item: 0/18 || Loss: 0.6292406320571899
--------------------------------------------------
##### Epoch: 21/1000 || Loss: 0.6217511296272278
--------------------------------------------------


  2%|▋                             | 21/1000 [00:39<30:00,  1.84s/it]

Epoch: 21 || Loss: 0.616389274597168 || Accuracy: 0.6945059895515442 || F1-score: 0.40985749771211927


  4%|█▏                            | 40/1000 [01:13<28:57,  1.81s/it]

Epoch: 41/1000 || Item: 0/18 || Loss: 0.5946145057678223
--------------------------------------------------
##### Epoch: 41/1000 || Loss: 0.5922710299491882
--------------------------------------------------


  4%|█▏                            | 41/1000 [01:15<29:06,  1.82s/it]

Epoch: 41 || Loss: 0.6086574594179789 || Accuracy: 0.6945059895515442 || F1-score: 0.41055803476102043


  6%|█▊                            | 60/1000 [01:51<28:31,  1.82s/it]

Epoch: 61/1000 || Item: 0/18 || Loss: 0.6061823964118958
--------------------------------------------------
##### Epoch: 61/1000 || Loss: 0.6151013970375061
--------------------------------------------------


  6%|█▊                            | 61/1000 [01:53<28:33,  1.82s/it]

Epoch: 61 || Loss: 0.6049279835489061 || Accuracy: 0.6945059895515442 || F1-score: 0.40985749771211927


  8%|██▍                           | 80/1000 [02:28<27:53,  1.82s/it]

Epoch: 81/1000 || Item: 0/18 || Loss: 0.6021712422370911
--------------------------------------------------
##### Epoch: 81/1000 || Loss: 0.6272469758987427
--------------------------------------------------


  8%|██▍                           | 81/1000 [02:29<27:53,  1.82s/it]

Epoch: 81 || Loss: 0.6026279860072665 || Accuracy: 0.6951705813407898 || F1-score: 0.4121890772319115


 10%|██▉                          | 100/1000 [03:04<27:12,  1.81s/it]

Epoch: 101/1000 || Item: 0/18 || Loss: 0.6228774785995483
--------------------------------------------------
##### Epoch: 101/1000 || Loss: 0.6305090188980103
--------------------------------------------------


 10%|██▉                          | 101/1000 [03:06<27:20,  1.82s/it]

Epoch: 101 || Loss: 0.5998154878616333 || Accuracy: 0.6967213153839111 || F1-score: 0.42435327765156233


 12%|███▍                         | 120/1000 [03:40<26:37,  1.82s/it]

Epoch: 121/1000 || Item: 0/18 || Loss: 0.6095082759857178
--------------------------------------------------
##### Epoch: 121/1000 || Loss: 0.6216039657592773
--------------------------------------------------


 12%|███▌                         | 121/1000 [03:42<26:33,  1.81s/it]

Epoch: 121 || Loss: 0.6011377308103774 || Accuracy: 0.6947275400161743 || F1-score: 0.43641766919813796


 14%|████                         | 140/1000 [04:17<26:02,  1.82s/it]

Epoch: 141/1000 || Item: 0/18 || Loss: 0.6173529028892517
--------------------------------------------------
##### Epoch: 141/1000 || Loss: 0.6040025353431702
--------------------------------------------------


 14%|████                         | 141/1000 [04:19<26:00,  1.82s/it]

Epoch: 141 || Loss: 0.5979373852411906 || Accuracy: 0.6956136226654053 || F1-score: 0.4405289172941609


 16%|████▋                        | 160/1000 [04:53<25:24,  1.81s/it]

Epoch: 161/1000 || Item: 0/18 || Loss: 0.5972712635993958
--------------------------------------------------
##### Epoch: 161/1000 || Loss: 0.5919632315635681
--------------------------------------------------


 16%|████▋                        | 161/1000 [04:55<25:52,  1.85s/it]

Epoch: 161 || Loss: 0.5977567632993063 || Accuracy: 0.6960567235946655 || F1-score: 0.445601483836778


 18%|█████▏                       | 180/1000 [05:30<25:02,  1.83s/it]

Epoch: 181/1000 || Item: 0/18 || Loss: 0.5730757117271423
--------------------------------------------------
##### Epoch: 181/1000 || Loss: 0.6022092700004578
--------------------------------------------------


 18%|█████▏                       | 181/1000 [05:32<25:53,  1.90s/it]

Epoch: 181 || Loss: 0.597872363196479 || Accuracy: 0.6938413977622986 || F1-score: 0.4516314592329457


 20%|█████▊                       | 200/1000 [06:14<26:16,  1.97s/it]

Epoch: 201/1000 || Item: 0/18 || Loss: 0.6095460057258606
--------------------------------------------------
##### Epoch: 201/1000 || Loss: 0.5573195815086365
--------------------------------------------------


 20%|█████▊                       | 201/1000 [06:16<25:48,  1.94s/it]

Epoch: 201 || Loss: 0.5946462551752726 || Accuracy: 0.6949490308761597 || F1-score: 0.44865803842447866


 22%|██████▎                      | 217/1000 [06:53<28:41,  2.20s/it]

Baselines:
   - just looking at the sentence embeddings (encodes nothing about the history on the post)
       - highlights importance of looking at the sequence
   - averaging history
   - comparing the cosine similarity between previous post and current post to see if switch
   
Test for:
- How many posts do you need to look back?

In [180]:
manifesto_data.df["doc_id"].value_counts()

51421_2015    1917
51320_2019    1702
51620_2015    1588
51620_2017    1496
51421_2019    1467
51320_2017    1328
51620_2019    1223
51421_2017    1131
51902_2019    1071
51320_2015    1009
51902_2015     892
51902_2017     786
Name: doc_id, dtype: int64

In [179]:
manifesto_data.df_padded[manifesto_data.df_padded["doc_id"]=="51320_2015"]

Unnamed: 0,timeline_index,d1,d2,d3,d4,d5,d6,d7,d8,d9,d10,doc_id,switched_topic
0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,51320_2015,-1
1,1,5.452659,3.602911,8.050925,10.075459,3.911385,9.213820,5.951643,4.275470,1.188564,2.751409,51320_2015,1
2,2,5.111912,3.501168,7.925458,10.427611,3.838215,9.157601,6.385989,4.594971,1.200033,3.107747,51320_2015,0
3,3,5.043742,3.522960,7.468404,9.604556,3.859646,9.537117,6.133909,4.508759,0.930652,2.949600,51320_2015,0
4,4,5.246913,3.351569,7.850750,8.427454,4.078413,9.972346,6.232757,4.770494,1.519472,2.442487,51320_2015,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1004,1004,4.939366,2.744935,6.556887,9.778514,5.032885,9.680021,8.184111,4.553108,1.828841,2.317884,51320_2015,1
1005,1005,3.629394,2.353125,7.404681,10.038898,5.321763,9.898950,7.223835,4.316514,2.549977,2.290533,51320_2015,1
1006,1006,4.761509,3.163177,7.822935,10.507176,4.688953,8.588090,8.575832,4.712253,1.761940,2.563453,51320_2015,1
1007,1007,4.707180,3.207190,7.543126,10.101868,4.077579,9.360678,6.383281,4.666095,1.278698,3.281799,51320_2015,0


In [183]:
manifesto_data.get_torch_path(include_time_features=False).shape

torch.Size([15610, 1, 10])

In [72]:
history_path[3]

array([[0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        '51320_2015', -1],
       [0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        '51320_2015', -1],
       [1, 5.452658653259277, 3.6029112339019775, 8.050925254821777,
        10.075458526611328, 3.9113845825195312, 9.213820457458496,
        5.951642990112305, 4.27547025680542, 1.1885640621185303,
        2.751408576965332, '51320_2015', 1],
       [2, 5.111911773681641, 3.5011675357818604, 7.9254584312438965,
        10.427611351013184, 3.838214874267578, 9.157601356506348,
        6.385989189147949, 4.594970703125, 1.2000325918197632,
        3.1077473163604736, '51320_2015', 0],
       [3, 5.043741703033447, 3.5229599475860596, 7.4684038162231445,
        9.6045560836792, 3.8596458435058594, 9.537117004394531,
        6.133909225463867, 4.508759498596191, 0.9306520223617554,
        2.9496004581451416, '51320_2015', 0]], dtype=object)

In [None]:
augmentation_tp = model_specifics["augmentation_tp"]
input_channels = path.shape[2]
output_channels =  [model_specifics["reduced_network_components"]] #13#[10,13]
augmentation_layers = () #[(32, 16, 10)] #(50, 20, output_channels) #
BiLSTM = False
sig_d = 3 
blocks = 3
post_dim = x_data.shape[1]- input_channels
hidden_dim_lstm =  [(12, 8)] #12 [10,12]
hidden_dim = [32] #32 [32,64] 
output_dim = 3
loss = model_specifics["loss_function"] #'focal' #cbfocal
dropout_rate = [0.25]  #0.25 [0.25,0.35]
if (model_specifics['time_injection_history_tp'] == 'timestamp'):
    add_time = True
else: 
    add_time = False

In [None]:
lass DeepSigNet(nn.Module):
    def __init__(
        self,
        input_channels,
        output_channels,
        sig_d,
        post_dim,
        hidden_dim,
        output_dim,
        dropout_rate,
        add_time=False,
        augmentation_tp="Conv1d",
        augmentation_layers=(),
    )