In [1]:
import torch
import os

# Switch the kernel's working directory to the thesis repository before the
# project imports below run.
# NOTE(review): presumably this is what lets the `sms.*` imports resolve - confirm.
# NOTE(review): absolute, machine-specific path; the notebook will not run
# unchanged on another machine.
os.chdir("c:/Users/cunn2/OneDrive/DSML/Project/thesis-repo")

from sms.exp1.config_classes import load_config_from_launchplan
from sms.exp1.run_training import build_encoder, build_projector
from sms.exp1.models.siamese import SiameseModel

# config = load_config_from_launchplan("sms/exp1/runs/run_20240926_162652/original_launchplan.yaml")

# encoder = build_encoder(config.model_dump())
# projector = build_projector(config.model_dump())

# model = SiameseModel(encoder, projector)

# print(encoder)
# print(projector)
# print(model)

In [2]:
# bert
from torch import nn
import numpy as np

class TokenAndPositionalEmbeddingLayer(nn.Module):
    """Embed per-step features and add fixed sinusoidal position codes.

    Every sequence element is projected from ``input_dim`` to ``emb_dim`` by a
    kernel-size-1 convolution (the same linear map applied at each position),
    scaled by ``sqrt(emb_dim)``, and summed with a precomputed sinusoidal
    positional table.

    Args:
        input_dim: Number of features per sequence element.
        emb_dim: Size of the embedding produced for each element.
        max_len: Longest sequence the positional table covers.
    """

    def __init__(self, input_dim, emb_dim, max_len):
        super().__init__()
        self.max_len = max_len
        self.emb_dim = emb_dim
        self.input_dim = input_dim
        self.token_emb = nn.Conv1d(self.input_dim, self.emb_dim, 1)
        # Registered as a buffer so the table follows the module through
        # .to()/.cuda() instead of being copied to the device on every forward.
        # persistent=False keeps it out of state_dict, so checkpoints saved
        # when this was a plain attribute still load without unexpected keys.
        self.register_buffer(
            "pos_emb",
            self.positional_encoding(self.max_len, self.emb_dim),
            persistent=False,
        )

    def get_angles(self, pos, i, emb_dim):
        """Return ``pos / 10000 ** (2 * (i // 2) / emb_dim)`` (NumPy broadcasting)."""
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(emb_dim))
        return pos * angle_rates

    def positional_encoding(self, position, emb_dim):
        """Build the sinusoidal table for ``position`` steps.

        Returns:
            A float32 tensor of shape ``(1, position, emb_dim)`` holding sines
            in the even columns and cosines in the odd columns.
        """
        angle_rads = self.get_angles(
            np.arange(position)[:, np.newaxis],
            np.arange(emb_dim)[np.newaxis, :],
            emb_dim,
        )

        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
        pos_encoding = angle_rads[np.newaxis, ...]
        return torch.tensor(pos_encoding, dtype=torch.float32)

    def forward(self, x):
        """Embed ``x`` and add the positional codes.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, emb_dim)``.

        Raises:
            ValueError: If ``seq_len`` is larger than ``max_len``, i.e. the
                positional table has no entry for the trailing positions.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len} "
                "of the positional encoding table"
            )
        # Conv1d expects channels first: (batch, input_dim, seq_len).
        x = torch.permute(x, (0, 2, 1))
        x = self.token_emb(x)
        # Out-of-place scaling: do not mutate the convolution output.
        x = x * torch.sqrt(torch.tensor(self.emb_dim, dtype=torch.float32))
        x = torch.permute(x, (0, 2, 1))
        return x + self.pos_emb[:, :seq_len]

class BertEncoder(nn.Module):
    """Transformer encoder that maps a sequence of points to one latent vector.

    Pipeline: token + positional embedding -> ``n_layers`` Transformer encoder
    layers -> per-position linear projection to ``d_latent`` -> mean over the
    sequence axis.

    Args:
        config: Mapping of hyperparameters, read with ``.get``. Keys used and
            their fallbacks: ``d_model`` (128), ``n_layers`` (4), ``n_heads``
            (8), ``d_ff`` (``4 * d_model``), ``dropout_rate`` (0.1) and
            ``max_seq_len`` (512). Any other key is ignored by this class.
        input_shape: Number of features per sequence element.
        d_latent: Size of the returned embedding.
    """

    def __init__(self, config, input_shape=2, d_latent=64):
        super().__init__()
        self.d_input = input_shape
        self.d_latent = d_latent
        self.d_model = config.get("d_model", 128)
        self.n_layers = config.get("n_layers", 4)

        self.emb = TokenAndPositionalEmbeddingLayer(
            input_dim=self.d_input,
            emb_dim=self.d_model,
            max_len=config.get("max_seq_len", 512),
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=config.get("n_heads", 8),
            dim_feedforward=config.get("d_ff", self.d_model * 4),
            dropout=config.get("dropout_rate", 0.1),
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=self.n_layers
        )
        self.fc = nn.Linear(self.d_model, self.d_latent)
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, batch):
        """Encode a batch of equally long sequences.

        Args:
            batch: Tensor of shape ``(batch_size, padded_seq_length, point_dim)``.

        Returns:
            Tensor of shape ``(batch_size, d_latent)``.
        """
        # The mask is all False, i.e. no position is excluded from attention,
        # because all inputs in a batch have the same length.
        # NOTE(review): positions filled with a padding sentinel upstream are
        # therefore attended to and averaged like real data - confirm intended.
        key_padding_mask = torch.zeros(
            (batch.shape[0], batch.shape[1]),
            dtype=torch.bool,
            device=batch.device,
        )
        batch_emb = self.emb(batch)  # (batch_size, padded_seq_length, d_model)
        # nn.TransformerEncoder.forward names this argument
        # `src_key_padding_mask`; any other keyword raises TypeError.
        batch_emb = self.transformer_encoder(
            batch_emb, src_key_padding_mask=key_padding_mask
        )  # (batch_size, padded_seq_length, d_model)
        batch_emb = self.fc(batch_emb)  # (batch_size, padded_seq_length, d_latent)
        # AdaptiveAvgPool1d pools over the last axis, so move the sequence there.
        batch_emb = torch.permute(batch_emb, (0, 2, 1))  # (batch_size, d_latent, padded_seq_length)
        batch_emb = self.pool(batch_emb)  # (batch_size, d_latent, 1)
        batch_emb = torch.squeeze(batch_emb, dim=2)  # (batch_size, d_latent)

        return batch_emb
    
# Load the validation chunks. `weights_only=False` is passed explicitly so the
# call behaves the same on every PyTorch version (the implicit default has
# changed between releases) and no longer emits the FutureWarning about it.
# NOTE(review): this performs a full pickle load - only use it on trusted files.
# NOTE(review): absolute, machine-specific path.
data = torch.load(
    r"C:\Users\cunn2\OneDrive\DSML\Project\thesis-repo\data\exp1\val_data.pt",
    weights_only=False,
)
# Longest chunk in the validation set; used below as the padding target length.
max_length = max(len(chunk) for chunk in data)

# Hand-written stand-in for a dumped launch-plan config, passed to build_encoder.
dumped_lp_config = {
    "encoder": {
        "type": "BertEncoder",
        "params": {
            "config": {
                "d_model": 128,
                "n_layers": 4,
                "n_heads": 8,
                "d_ff": 512,
                "d_expander": 256,
                "dropout_rate": 0.1,
                "max_seq_len": 512
            }
        }
    },
    "dims": {
        "input_shape": 2,
        "d_latent": 64
    },
    # Keyword arguments later unpacked into InputFormatter(**...).
    "input": {
        "make_relative_pitch": True,
        "normalize_octave": False,
        "piano_roll": False,
        "quantize": False,
        "rest_pitch": -1,
        "steps_per_bar": 32,
        "pad_sequence": True,
        "pad_val": -1000,
        "goal_seq_len": max_length
    }
}

# NOTE(review): build_encoder is imported from sms.exp1.run_training; whether it
# instantiates the BertEncoder class defined in this notebook or the project's
# own implementation is not visible here - confirm.
encoder = build_encoder(dumped_lp_config)


  data = torch.load(r"C:\Users\cunn2\OneDrive\DSML\Project\thesis-repo\data\exp1\val_data.pt")


In [3]:
# Display the module tree of the encoder returned by build_encoder.
encoder

BertEncoder(
  (emb): TokenAndPositionalEmbeddingLayer(
    (token_emb): Conv1d(2, 128, kernel_size=(1,), stride=(1,))
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc): Linear(in_features=128, out_features=64, bias=True)
  (pool): AdaptiveAvgPool1d(output_size=1)
)

In [4]:
from sms.src.synthetic_data.formatter import InputFormatter

# First ten entries of the loaded validation data.
data_ex = data[:10]

# Formatter configured from the "input" section of the config dict.
formatter = InputFormatter(**dumped_lp_config['input'])

# Format every chunk, cast it to float32 and wrap it as a torch tensor.
formatted_data_list = [
    torch.from_numpy(
        formatter(chunk).astype(np.float32).copy()
    )
    for chunk in data_ex
]

# Stack along a new leading axis -> shape [num_chunks, *input_shape].
formatted_data_stacked = torch.stack(formatted_data_list)

In [5]:
# Run the stacked, formatted chunks through the encoder and display the result.
# NOTE(review): neither encoder.eval() nor torch.no_grad() is set in the visible
# cells, so dropout is presumably active and the values vary per run - confirm.
encoder(formatted_data_stacked)

tensor([[-6.3793e-01,  3.5387e-01,  1.0780e-02,  1.2166e+00,  4.6569e-01,
          3.8973e-01, -6.7486e-01,  1.0874e-02,  2.3693e-01,  1.0819e-02,
         -9.4808e-01,  4.6273e-01, -2.2738e-01, -6.4120e-01, -2.5948e-01,
          3.9248e-01, -1.1037e+00,  4.0670e-01,  2.2948e-01, -2.9387e-01,
          7.1413e-01, -9.0243e-01, -4.2870e-02,  6.1272e-01, -3.8800e-01,
          5.6685e-01, -5.9928e-01, -2.3420e-01,  2.5666e-01, -6.5252e-01,
         -4.7824e-01, -2.6841e-02, -2.1075e-01,  1.1169e+00, -7.0040e-01,
         -6.6192e-01,  2.1128e-01, -5.6558e-01, -1.2892e-03, -1.0727e+00,
         -1.7289e-01, -1.8228e-01,  8.1735e-01,  1.1020e-01,  4.5572e-01,
         -2.2616e-01, -6.2268e-01, -1.3244e-01, -3.4373e-01,  4.2239e-01,
          1.5603e-01, -1.3300e-01, -2.6821e-01,  3.8612e-01, -2.0867e-01,
         -1.9862e-01,  2.0648e-01,  8.5957e-01, -2.5489e-01,  9.9180e-01,
         -3.7443e-01, -1.8460e-04,  2.0290e-01,  3.9178e-01],
        [-9.3135e-01, -2.2460e-02, -9.4327e-02,  1

In [6]:
# Rebind `formatter` to an InputFormatter that only sets pad_sequence=True;
# every other option keeps the formatter's own default.
# NOTE(review): this overwrites the config-driven `formatter` created in the
# earlier cell, so re-running later code picks up this instance instead.
formatter = InputFormatter(pad_sequence=True)

# Print the first chunk before and after formatting to compare them.
print(data[0])
print(formatter(data[0]))


[[ 0.2 71. ]
 [ 3.  69. ]
 [ 0.8 74. ]]
[[ 2.0e-01  7.1e+01]
 [ 3.0e+00  6.9e+01]
 [ 8.0e-01  7.4e+01]
 [-1.0e+03 -1.0e+03]
 [-1.0e+03 -1.0e+03]
 [-1.0e+03 -1.0e+03]
 [-1.0e+03 -1.0e+03]
 [-1.0e+03 -1.0e+03]
 [-1.0e+03 -1.0e+03]
 [-1.0e+03 -1.0e+03]
 [-1.0e+03 -1.0e+03]
 [-1.0e+03 -1.0e+03]]
