In [None]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from pacmap import PaCMAP
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, Subset, random_split

In [None]:
PACMAP_SIZE = 10000  # Max number of validation samples to embed and project with PaCMAP
NUM_TARGETS = 1  # Number of resource features (y)
TRAIN_SPLIT = 0.8  # Fraction of samples used for training; the remainder is validation
BATCH_SIZE = 1024
EPOCHS = 10
FILE = "data/libxml2_byteArrays.pkl"  # Pickled (byte_sequences, labels) pair
SEED = 42  # Global seed so weight init, splitting and shuffling are reproducible

# Seed the global RNGs once, up front, so Restart & Run All reproduces the same results.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Model Definitions

In [None]:
class TransformerModel(nn.Module):
    """Transformer encoder over integer token sequences with mean pooling and a linear head.

    Args:
        input_dim: Vocabulary size AND model width. Token ids must lie in
            ``[0, input_dim)``, and each is embedded into an ``input_dim``-dim vector.
        num_heads: Attention heads per encoder layer (must divide ``input_dim``).
        num_layers: Number of stacked ``nn.TransformerEncoderLayer`` blocks.
        num_outputs: Width of the final linear head.
        max_len: Optional. When given, a learned positional embedding for positions
            ``[0, max_len)`` is added to the token embeddings. Without it the encoder
            has no notion of token order, so permuting a sequence leaves the pooled
            vector unchanged.
        pad_idx: Optional. When given, tokens equal to ``pad_idx`` are masked out as
            attention keys and excluded from the mean pooling.

    The defaults (``max_len=None, pad_idx=None``) reproduce the original model exactly.
    """

    def __init__(self, input_dim, num_heads, num_layers, num_outputs, max_len=None, pad_idx=None):
        super().__init__()
        self.embed = nn.Embedding(input_dim, input_dim)
        transformer_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads)
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)
        self.fc = nn.Linear(input_dim, num_outputs)
        # Created last so that the default configuration draws the same initial weights as before.
        self.pos_embed = nn.Embedding(max_len, input_dim) if max_len is not None else None
        self.pad_idx = pad_idx

    def pooling(self, x):
        """Encode token ids ``x`` of shape (batch, seq) into pooled (batch, input_dim) vectors."""
        h = self.embed(x)  # (batch, seq, dim)
        if self.pos_embed is not None:
            positions = torch.arange(x.size(1), device=x.device)
            h = h + self.pos_embed(positions)  # (seq, dim) broadcast over the batch
        # (batch, seq) boolean mask, True where the token is padding.
        pad_mask = None if self.pad_idx is None else x.eq(self.pad_idx)
        h = h.permute(1, 0, 2)  # (seq, batch, dim): the encoder layer is not batch_first
        h = self.transformer_encoder(h, src_key_padding_mask=pad_mask)
        if pad_mask is None:
            return h.mean(dim=0)
        keep = (~pad_mask).t().unsqueeze(-1)  # (seq, batch, 1)
        # masked_fill (not multiply) so a fully padded row pools to zeros instead of NaN.
        h = h.masked_fill(~keep, 0.0)
        counts = keep.sum(dim=0).clamp(min=1).to(h.dtype)
        return h.sum(dim=0) / counts

    def forward(self, x):
        """Return head outputs of shape (batch, num_outputs) for token ids ``x``."""
        return self.fc(self.pooling(x))

class ByteSequenceModel(LightningModule):
    """LightningModule that regresses targets from byte sequences with a ``TransformerModel``.

    Args:
        input_dim: Vocabulary size / model width passed to ``TransformerModel``.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        num_outputs: Number of regression targets predicted per sequence.
        lr: Adam learning rate.

    The defaults reproduce the original hard-coded configuration.
    """

    def __init__(self, input_dim=256, num_heads=8, num_layers=4, num_outputs=1, lr=0.001):
        super().__init__()
        self.lr = lr
        self.model = TransformerModel(
            input_dim=input_dim, num_heads=num_heads, num_layers=num_layers, num_outputs=num_outputs
        )

    def forward(self, x):
        return self.model(x)

    def _compute_loss(self, batch):
        """Return the MSE between predictions and targets for one ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)  # (batch, num_outputs)
        # Labels may arrive as (batch,) while predictions are (batch, 1). Without the reshape,
        # mse_loss broadcasts them to (batch, batch) and averages the wrong differences.
        return F.mse_loss(y_hat, y.reshape_as(y_hat))

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        # Log it; otherwise the validation loss is computed and silently discarded.
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Data loading definitions

In [None]:
class ByteSequenceDataset(Dataset):
    """Map-style dataset that pairs each encoded byte sequence with its label by position.

    Args:
        sequences: Indexable collection of encoded sequences.
        labels: Indexable collection of labels, aligned with ``sequences``.

    Raises:
        ValueError: If ``sequences`` and ``labels`` differ in length.
    """

    def __init__(self, sequences, labels):
        # Fail fast: a mismatch would otherwise surface as an IndexError mid-epoch,
        # or extra labels would be silently ignored.
        if len(sequences) != len(labels):
            raise ValueError(
                f"sequences and labels must have the same length, got {len(sequences)} and {len(labels)}"
            )
        self.sequences = sequences
        self.labels = labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]

In [None]:
class ByteSequenceDataModule(LightningDataModule):
    """Serve pre-split train/validation datasets to a Lightning ``Trainer``.

    Subclasses ``LightningDataModule`` (not ``LightningModule``) because it is
    passed to ``Trainer.fit`` through the ``datamodule=`` argument.

    Args:
        train_dataset: Dataset used for training; reshuffled every epoch.
        val_dataset: Dataset used for validation; iterated in a fixed order.
        batch_size: Samples per batch for both loaders.
    """

    def __init__(self, train_dataset, val_dataset, batch_size=64):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        # Reshuffle each epoch so mini-batch composition varies between epochs.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)


# Load data

In [None]:
# NOTE(review): pickle.load can execute arbitrary code; only load FILE if it comes from a trusted source.
with open(FILE, 'rb') as pkl_file:
    payload = pickle.load(pkl_file)
# The pickle holds a two-element (byte_sequences, labels) pair.
byte_sequences, labels = payload

In [None]:
# Pad (with 0) or truncate every byte sequence to MAX_SEQ_LEN ints so all rows have equal length.
# Sequences longer than MAX_SEQ_LEN are truncated; without truncation the rows would be ragged
# and the later tensor construction would fail.
# NOTE(review): 0 is also a valid byte value, so padding is indistinguishable from real 0x00 bytes.
MAX_SEQ_LEN = 150
int_sequences = [
    [int(b) for b in byte_seq[:MAX_SEQ_LEN]] + [0] * (MAX_SEQ_LEN - min(len(byte_seq), MAX_SEQ_LEN))
    for byte_seq in byte_sequences
]

In [None]:
# torch.tensor with an explicit dtype replaces the legacy typed constructors (LongTensor/FloatTensor).
int_sequences_tensor = torch.tensor(int_sequences, dtype=torch.long)
labels_tensor = torch.tensor(labels, dtype=torch.float32)
assert int_sequences_tensor.shape[0] == labels_tensor.shape[0], "one label per sequence expected"

In [None]:
# Wrap the aligned sequence/label tensors in a map-style dataset.
dataset = ByteSequenceDataset(sequences=int_sequences_tensor, labels=labels_tensor)


In [None]:
# Split into train/validation. The fixed-seed generator makes the split identical across kernel restarts.
train_size = int(TRAIN_SPLIT * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)


In [None]:
# Instantiate the network and the data module that feeds it.
model = ByteSequenceModel()
data_module = ByteSequenceDataModule(train_dataset=train_dataset, val_dataset=val_dataset, batch_size=BATCH_SIZE)

# Train model

In [None]:
# Train for EPOCHS epochs using the loaders provided by the data module.
trainer = Trainer(max_epochs=EPOCHS)
trainer.fit(model=model, datamodule=data_module)

# Visualize pooled contextual embeddings with PaCMAP

In [None]:
def get_embeddings(model, loader):
    """Return the pooled encoder output (``model.model.pooling``) for every batch in ``loader``.

    Args:
        model: Wrapper whose ``.model`` exposes a ``pooling(x)`` method.
        loader: Iterable of ``(x, y)`` batches; ``y`` is ignored.

    Returns:
        A CPU tensor with the per-batch embeddings concatenated along dim 0.
        ``torch.cat`` raises if ``loader`` yields no batches.

    The model's train/eval mode is restored afterwards, so calling this does not
    silently leave a model in eval mode for later training.
    """
    was_training = model.training
    first_param = next(model.parameters(), None)
    # Lightning may leave the model on an accelerator; send inputs to wherever the weights live.
    device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    embeddings = []
    try:
        with torch.no_grad():
            for x, _ in loader:
                # Collect on CPU so accelerator memory does not grow with the dataset size.
                embeddings.append(model.model.pooling(x.to(device)).cpu())
    finally:
        model.train(was_training)
    return torch.cat(embeddings)

# Embed at most PACMAP_SIZE validation samples. The clamp avoids an IndexError when
# the validation split is smaller than PACMAP_SIZE.
n_pacmap = min(PACMAP_SIZE, len(val_dataset))
val_loader = DataLoader(Subset(val_dataset, range(n_pacmap)), batch_size=BATCH_SIZE)
embeddings = get_embeddings(model, val_loader)

In [None]:
# Project the pooled embeddings to 2-D with PaCMAP and scatter-plot them.
pacmap_instance = PaCMAP(n_components=2, n_neighbors=10, MN_ratio=0.5, FP_ratio=2.0)
reduced_embeddings = pacmap_instance.fit_transform(embeddings.cpu().numpy())

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], alpha=0.7)
ax.set_title("PaCMAP Visualization of Contextual Embeddings")
ax.set_xlabel("PaCMAP dimension 1")
ax.set_ylabel("PaCMAP dimension 2")

plt.show()