In [31]:
import importlib.metadata
import json
import logging
import os
import re
import tempfile
import time
import ast
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
import math
import aeon

import torch
import torch.nn as nn


In [2]:
import os
import numpy as np
import aeon
from aeon.datasets import load_from_tsfile

In [3]:
DATA_PATH = "DATA/"

In [4]:
train_x, train_y = aeon.datasets.load_from_tsfile(DATA_PATH + "Blink_TRAIN.ts")
test_x, test_y = aeon.datasets.load_from_tsfile(DATA_PATH + "Blink_TEST.ts")

train_x, train_y = np.array(train_x), np.array(train_y)
test_x, test_y = np.array(test_x), np.array(test_y)

#reshape from (sample, feat_dim, seq_length) to (seq_length, sample, feat_dim)
train_x, test_x = np.transpose(train_x, (2, 0, 1)), np.transpose(test_x, (2, 0, 1))

# Separate x dimensions into 2 modalities
m1_train_x = train_x[:, :, :2]
m2_train_x = train_x[:, :, 2:]
m1_train_y, m2_train_y = train_y, train_y

#preserve labels
m1_test_x = test_x[:, :, :2]
m2_test_x = test_x[:, :, 2:]
m1_test_y, m2_test_y = test_y, test_y

print(m1_train_x.shape, m1_train_y.shape)

(510, 500, 2) (500,)


In [5]:
class PositionalEncoding(torch.nn.Module):
    r"""
    Implemented from "Language Modeling with nn.Transformer and TorchText" 

    To inject positional information into the embeddings, we use add a embedding based on the mapping of sin/cosine to our original embedding. 
    REMARKS: do we need to add this if our representations already host positional information?

    Args: 
        d_model: dimension of the embeddings, where embedding is shape [n_sample, seq_length, embedding_dim (d_model)]
    """

    def __init__(self, d_model: int, dropout: float = 0.1, seq_len: int = 5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(seq_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[n_sample, seq_length, embedding_dim]``
        """
        print(self.pe.shape)
        print(x.shape)

        #this transformation is for [n_sample, seq_length, embedding_dim]
        x = x + self.pe[:, :x.size(1)]


        #this transformation is for [seq_length n_sample, embedding_dim]
        #x = x + self.pe[:x.size(0)]
        
        return self.dropout(x)

In [6]:
pe_ = PositionalEncoding(d_model = 2, seq_len = 510)

In [7]:
hell0 = pe_(torch.Tensor(m1_train_x))

torch.Size([510, 1, 2])
torch.Size([510, 500, 2])


In [8]:
class cross_attn_block(torch.nn.Module):
    r"""
    Single Block for Cross Attention

    Args: 
        m1: first modality
        m2: second modality

    Shapes: 
        m1: (seq_length, N_samples, N_features)
        m2: (seq_length, N_samples, N_features)

    Returns: 
        embedding of m1 depending on attending on certain elements of m2, multihead_attn(k_m1, v_m1, q_m2)
    """

    def __init__(self, 
                 dim: int, 
                 heads: int, 
                 dropout: float, 
                 seq_length: int):

        super(cross_attn_block, self).__init__()

        self.positional_encoding = PositionalEncoding(dim, dropout, seq_length)

        #learnable
        self._to_key = torch.nn.Linear(dim, dim)
        self._to_query = torch.nn.Linear(dim, dim)
        self._to_value = torch.nn.Linear(dim, dim)

        self.attn = torch.nn.MultiheadAttention(embed_dim = dim, num_heads = heads, dropout = dropout)

    def forward(self, 
                m1: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, 
                m2: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, 
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        
        m1_x = self.positional_encoding(m1)
        m2_x = self.positional_encoding(m2)
        print("passed encoding")

        m1_k = self._to_key(m1_x)
        m1_v = self._to_query(m1_x)
        m2_q = self._to_value(m2_x)
        print("passed kqv")

        #crossing
        cross_x, attn_weights = self.attn(m1_k, m1_v, m2_q)
        print("passed attn:", cross_x.shape)

        return cross_x


class position_wise_ffn(torch.nn.Module):
    r"""
    Position-wise feed-forward network with a RELU activation - essentially contracts output, and squeezes it back to the same space

    ARGS:
        dim: dimension of the embeddings
        hidden_dim: dimension of the inflated hidden layer in feed-forward network
    
    """

    def __init__(self, 
                 dim: int, 
                 hidden_dim: int, 
                 dropout: float = 0.0):
        super(position_wise_ffn, self).__init__()

        self.ffn_1 = torch.nn.Linear(dim, hidden_dim)
        self.ffn_2 = torch.nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        x = self.ffn_1(x).relu()
        x = self.ffn_2(x)

        return x


class cross_attn_channel(torch.nn.Module):
    r"""
    Model for Cross Attention, architecture implementation taken from encoder layer of "Attention is all you need"
    Includes multi-head attn with crossing --> add + norm --> positionwise ffn --> add + norm --> output (based on paper)

    ARGS:
        dim_m1: time series modality 1
        dim_m2: time series modality 2

    Shapes:
        assuming seq_length is same for both
        m1: (seq_length, N_samples, N_features)
        m2: (seq_length, N_samples, N_features)
    """

    def __init__(self, 
                 dim_m1: int, 
                 dim_m2: int, 
                 outdim_m1: int, 
                 outdim_m2: int,
                 heads: Optional[int], 
                 seq_len: int, 
                 dropout: float = 0.0):
        super(cross_attn_channel, self).__init__()

        self.m1_cross_m2 = cross_attn_block(dim = dim_m1, heads = heads, dropout = dropout, seq_length = seq_len)
        self.m2_cross_m1 = cross_attn_block(dim = dim_m2, heads = heads, dropout = dropout, seq_length = seq_len)

        self.norm_m1 = torch.nn.LayerNorm(dim_m1)
        self.norm_m2 = torch.nn.LayerNorm(dim_m2)

        self.m1_pffn = position_wise_ffn(dim_m1, 512)
        self.m2_pffn = position_wise_ffn(dim_m2, 512)

        self.norm_pffn_m1 = torch.nn.LayerNorm(dim_m1)
        self.norm_pffn_m2 = torch.nn.LayerNorm(dim_m2)

        self.dropout = torch.nn.Dropout(dropout)



    def forward(self, 
                m1: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, 
                m2: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, 
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:

        m1_x = self.m1_cross_m2(m1, m2)
        m2_x = self.m2_cross_m1(m2, m1)

        m1_x = self.norm_m1(m1 + self.dropout(m1_x))
        m2_x  = self.norm_m2(m2 + self.dropout(m2_x))

        m1_ffn = self.m1_pffn(m1_x)
        m2_ffn = self.m2_pffn(m2_x)

        m1_x = self.norm_pffn_m1(m1_x + self.dropout(m1_ffn))
        m2_x = self.norm_pffn_m2(m2_x + self.dropout(m2_ffn))

        return m1_x, m2_x



In [9]:
print(m2_train_x.shape, m1_train_x.shape)

(510, 500, 2) (510, 500, 2)


In [10]:
cross_attn_channel_ = cross_attn_channel(dim_m1 = m1_train_x.shape[-1], dim_m2 = m2_train_x.shape[-1], outdim_m1=16, outdim_m2=16, heads = 2, seq_len =  m2_train_x.shape[0])

In [16]:
m1, m2 = cross_attn_channel_(torch.Tensor(m1_train_x), torch.Tensor(m2_train_x))

torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
passed encoding
passed kqv
passed attn: torch.Size([510, 500, 2])
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
passed encoding
passed kqv
passed attn: torch.Size([510, 500, 2])


In [17]:
m1.shape

torch.Size([510, 500, 2])

In [146]:
m2_train_x.shape[-1]

510

In [288]:
encoder_layer = torch.nn.TransformerEncoderLayer(d_model=512, nhead=2)
transformer_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=2)
src = torch.rand(10, 32, 512)
out = transformer_encoder(src)

classifier = torch.nn.Linear(512, 2)
out = classifier(out)



In [289]:
out.shape

torch.Size([10, 32, 2])

In [14]:
### testing self attention channel 

class self_attn_block(torch.nn.Module):
    r"""
    self attention block 

    Args: 
        dim: Dimension of the embeddings for this modality 
        heads: Number of attention heads
        dropout: Dropout rate
        seq_length: Sequence length

    Shapes: 
        x: (seq_length, N_samples, N_features)

    Returns: 
        Embedding of x after self-attention with same input dimensions
    """

    def __init__(self, 
                 dim: int, 
                 heads: int, 
                 dropout: float, 
                 seq_length: int, 
                  add_positional: Optional[bool] = False):

        super(self_attn_block, self).__init__()

        self.add_positional = add_positional
        self.positional_encoding = PositionalEncoding(dim, dropout, seq_length)

        # learnable linear projections 
        self._to_key = torch.nn.Linear(dim, dim)
        self._to_query = torch.nn.Linear(dim, dim)
        self._to_value = torch.nn.Linear(dim, dim)

        # Multi-head attention layer
        self.attn = torch.nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout)

    def forward(self, 
                x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], 
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        
        # (optional) positional encoding
        print('positional encoding...')
        if self.add_positional:
            x = self.positional_encoding(x)
        print('done!')

        # project input to q, k, v 
        print('k, q, v')
        k = self._to_key(x)
        q = self._to_query(x)
        v = self._to_value(x)
        print('done!')

        # Self-attention: each element attends to all elements within the sequence
        print('self attn')
        attn_x, attn_weights = self.attn(q, k, v, attn_mask=mask)
        print('done!')

        return attn_x

class self_attn_channel(torch.nn.Module):
    r"""
    Self-Attention Channel Model, based on the architecture of "Attention is All You Need"
    Includes self-attention, add + norm, position-wise FFN, add + norm

    Args:
        dim: Dimension of the embeddings
        pffn_dim: Dimension of hidden layer in position-wise FFN
        heads: Number of attention heads
        seq_len: Length of sequence
        dropout: Dropout rate

    Shapes:
        x: (seq_length, N_samples, N_features)

    Returns:
        Transformed x, same dimension as input with self-attention applied
    """

    def __init__(self, 
                 dim: int, 
                 pffn_dim: int, 
                 heads: Optional[int], 
                 seq_len: int, 
                 dropout: float = 0.0):
        
        super(self_attn_channel, self).__init__()

        # Self-attention block
        self.self_attn = self_attn_block(dim=dim, heads=heads, dropout=dropout, seq_length=seq_len)

        # Layer normalization for self-attention output
        self.norm_self_attn = torch.nn.LayerNorm(dim)

        # Position-wise feed-forward network
        self.pffn = position_wise_ffn(dim, pffn_dim)

        # Layer normalization for FFN output
        self.norm_pffn = torch.nn.LayerNorm(dim)

        # Dropout
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, 
                x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], 
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:

        # self-attn and add residual connection
        attn_x = self.self_attn(x, mask=mask)
        x = self.norm_self_attn(x + self.dropout(attn_x))

        # position-wise ffn and add residual connection
        ffn_x = self.pffn(x)
        x = self.norm_pffn(x + self.dropout(ffn_x))

        return x

In [20]:
print(m1_train_x.shape)

(510, 500, 2)


In [15]:
# instantiate the self attn module for m1
m1_self_attn_channel = self_attn_channel(dim = m1_train_x.shape[-1], pffn_dim=16, heads = 2, seq_len =  m1_train_x.shape[0])

In [21]:
m1_self_attn = m1_self_attn_channel(torch.Tensor(m1_train_x))

positional encoding...
done!
k, q, v
done!
self attn
done!


In [22]:
m1_self_attn.shape # should match (seq_len, n_samples, n_features)

torch.Size([510, 500, 2])

In [70]:
### test the MMCA model 

class MMCA(torch.nn.Module):
    r"""
    Torch implentation of Multi-Modal Cross Attention
    """
    def __init__(self, 
                #input shapes (Embedding Size, Time Length) 
                m1_shape: Optional[Tuple[int, int]] = None,
                m2_shape: Optional[Tuple[int, int]] = None,
                #modality 1
                m1_self_attn_layers: Optional[int] = None,
                m1_self_attn_heads: Optional[int] = None, 
                m1_cross_attn_layers: Optional[int] = None,
                m1_cross_attn_heads: Optional[int] = None,
                #modality 2
                m2_self_attn_layers: Optional[int] = None,
                m2_self_attn_heads: Optional[int] = None, 
                m2_cross_attn_layers: Optional[int] = None,
                m2_cross_attn_heads: Optional[int] = None,
                #classifier
                classifier: Optional[Any] = None):

        super(MMCA, self).__init__()

        # Multi-modal self-attention layers
        # modality 1 self-attention
        self.m1_self_attn = nn.Sequential(*[
            self_attn_block(m1_shape[1], m1_self_attn_heads, dropout=0.1, seq_length=m1_shape[0]) 
            for _ in range(m1_self_attn_layers)
        ])
        
        # modality 1 cross-attention
        self.m1_cross_attn = nn.Sequential(*[
            cross_attn_block(m1_shape[1], m1_cross_attn_heads, dropout=0.1, seq_length=m1_shape[0]) 
            for _ in range(m1_cross_attn_layers)
        ])
        
        # modality 2 self-attention
        self.m2_self_attn = nn.Sequential(*[
            self_attn_block(m2_shape[1], m2_self_attn_heads, dropout=0.1, seq_length=m2_shape[0]) 
            for _ in range(m2_self_attn_layers)
        ])
        
        # modality 2 cross-attention
        self.m2_cross_attn = nn.Sequential(*[
            cross_attn_block(m2_shape[1], m2_cross_attn_heads, dropout=0.1, seq_length=m2_shape[0]) 
            for _ in range(m2_cross_attn_layers)
        ])
        
        # classification head 
        self.classifier = classifier or nn.Linear(8, 2) # default classifier 

    
    def forward(self, 
            time_series_1: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
            time_series_2: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]) -> torch.Tensor:
        
        
        #1. Run through self_attn channel
        for layer in self.m1_self_attn:
            m1_self_attn_x = layer(time_series_1)
        
        for layer in self.m2_self_attn:
            m2_self_attn_x = layer(time_series_2)
        
        
        #2. Run through cross_attn channel
        m1_cross_attn_x = time_series_1
        m2_cross_attn_x = time_series_2
        for layer in self.m1_cross_attn:
            m1_cross_attn_x = layer(m1_cross_attn_x, time_series_2)
        
        for layer in self.m2_cross_attn:
            m2_cross_attn_x = layer(m2_cross_attn_x, time_series_1)
        
        
        print(f"m1_self_attn_x: {m1_self_attn_x.shape}")
        print(f"m2_self_attn_x: {m2_self_attn_x.shape}")
        print(f"m1_cross_attn_x: {m1_cross_attn_x.shape}")
        print(f"m2_cross_attn_x: {m2_cross_attn_x.shape}")
        
        #3 process (?)
        
        #4. Concatenate the outputs from all channels
        print(f'concatenating channels')
        concatenated_x = torch.cat([m1_self_attn_x, m1_cross_attn_x, m2_self_attn_x, m1_cross_attn_x], dim = -1)
        print(f'done. concatenated_x has shape {concatenated_x.shape}')
        
        #5. process (?)
        # Olivia: I am simply using mean pooling along the seq dimension to 
        # reduce to a single feature vector. 
        print(f'average pooling')
        pooled_concat_X = concatenated_x.mean(dim=0)
        print(f'done. pooled_concat_X has shape {pooled_concat_X.shape}')

        #6. Classify the concatenated output
        x = self.classifier(pooled_concat_X)

        return x, torch.sigmoid(x)

In [71]:
mmca = MMCA(
    m1_shape=[m1_train_x.shape[0], m1_train_x.shape[-1]],  #[leq_len, n_features]
    m2_shape=[m2_train_x.shape[0], m2_train_x.shape[-1]],
    m1_self_attn_layers=2,
    m1_self_attn_heads=2, 
    m1_cross_attn_layers=2,
    m1_cross_attn_heads=2,
    m2_self_attn_layers=2,
    m2_self_attn_heads=2, 
    m2_cross_attn_layers=2,
    m2_cross_attn_heads=2,
)

In [72]:
mmca_out = mmca(torch.Tensor(m1_train_x), torch.Tensor(m2_train_x))

positional encoding...
done!
k, q, v
done!
self attn
done!
positional encoding...
done!
k, q, v
done!
self attn
done!
positional encoding...
done!
k, q, v
done!
self attn
done!
positional encoding...
done!
k, q, v
done!
self attn
done!
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
passed encoding
passed kqv
passed attn: torch.Size([510, 500, 2])
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
passed encoding
passed kqv
passed attn: torch.Size([510, 500, 2])
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
passed encoding
passed kqv
passed attn: torch.Size([510, 500, 2])
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
torch.Size([510, 1, 2])
torch.Size([510, 500, 2])
passed encoding
passed kqv
passed attn: torch.Size([510, 500, 2])
m1_self_attn_x: torch.Size([510, 500, 2])
m2_self_attn_x: torch.Size([510, 500, 2])
m1_cross_attn_x:

In [75]:
len(mmca_out)

2

In [79]:
mmca_out[0].shape # y_pred 

torch.Size([500, 2])

In [78]:
mmca_out[1].shape # y_pred_class

torch.Size([500, 2])