In [1]:
import re
from sklearn.model_selection import train_test_split
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
from tqdm import tqdm

In [2]:
# Load the dataset; the first CSV column is used as the DataFrame index.
df = pd.read_csv("master_final.csv", index_col=0)

In [3]:
def string_parser(s: str) -> list:
    """Parse a stringified numeric vector into a list of floats.

    Accepts both whitespace-separated text (e.g. ``"[0.1 0.2\\n 0.3]"``) and
    comma-separated text (e.g. ``"[0.1, 0.2, 0.3]"``). Square brackets are
    removed before the text is split.

    Args:
        s: Text representation of the vector.

    Returns:
        The parsed values as Python floats, in their original order. A string
        with no numeric tokens (e.g. ``"[]"``) yields an empty list.

    Raises:
        ValueError: If a token cannot be converted by ``float``.
    """
    without_brackets = s.replace("]", "").replace("[", "")

    # Commas are treated as separators alongside whitespace; empty tokens from
    # leading/trailing separators are dropped by the filter.
    tokens = re.split(r"[\s,]+", without_brackets)
    return [float(token) for token in tokens if token]

In [4]:
# Turn each stringified embedding column into a list of float lists.
X = [string_parser(cell) for cell in df["contextual_meta_embedding"]]
y_exec = [string_parser(cell) for cell in df["exec_embedding"]]
y_style = [string_parser(cell) for cell in df["style_embedding"]]
y_content = [string_parser(cell) for cell in df["content_embedding"]]

In [5]:
# Hold out 10% of the samples as the test set. All four lists go through one
# call so that inputs and the three targets are split with the same indices.
(
    X_train_val,
    X_test,
    y_content_train_val,
    y_content_test,
    y_exec_train_val,
    y_exec_test,
    y_style_train_val,
    y_style_test,
) = train_test_split(
    X,
    y_content,
    y_exec,
    y_style,
    test_size=0.1,
    random_state=42,
)

In [7]:
# Carve the validation set out of the remaining 90%: 1/9 of that remainder is
# roughly 10% of the full dataset, giving an approximate 80/10/10 split.
(
    X_train,
    X_val,
    y_content_train,
    y_content_val,
    y_exec_train,
    y_exec_val,
    y_style_train,
    y_style_val,
) = train_test_split(
    X_train_val,
    y_content_train_val,
    y_exec_train_val,
    y_style_train_val,
    test_size=1/9,
    random_state=42,
)

In [8]:
class PromptDataset(Dataset):
    """Dataset yielding a context embedding together with its three label embeddings.

    All four inputs are converted to float32 tensors once, at construction
    time. Indexing returns the tuple
    ``(context_embed, content_label, exec_label, style_label)`` for that row.
    """

    @staticmethod
    def _as_float32(values):
        """Build a new float32 tensor from ``values``."""
        return torch.tensor(values, dtype=torch.float32)

    def __init__(self, context_embeds, content_labels, exec_labels, style_labels):
        convert = self._as_float32
        self.context_embeds = convert(context_embeds)
        self.content_labels = convert(content_labels)
        self.exec_labels = convert(exec_labels)
        self.style_labels = convert(style_labels)

    def __len__(self):
        # The dataset length is defined by the context embeddings alone.
        return len(self.context_embeds)

    def __getitem__(self, idx):
        sample = (
            self.context_embeds[idx],
            self.content_labels[idx],
            self.exec_labels[idx],
            self.style_labels[idx],
        )
        return sample

In [9]:
# One PromptDataset per split, each pairing inputs with the three target sets.
train_dataset = PromptDataset(
    context_embeds=X_train,
    content_labels=y_content_train,
    exec_labels=y_exec_train,
    style_labels=y_style_train,
)
val_dataset = PromptDataset(
    context_embeds=X_val,
    content_labels=y_content_val,
    exec_labels=y_exec_val,
    style_labels=y_style_val,
)
test_dataset = PromptDataset(
    context_embeds=X_test,
    content_labels=y_content_test,
    exec_labels=y_exec_test,
    style_labels=y_style_test,
)

In [11]:
BATCH_SIZE = 32

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [12]:
class PromptDecomposerMLP(nn.Module):
    """MLP with a shared trunk feeding three parallel linear output heads.

    The trunk applies two Linear -> ReLU -> Dropout stages, the second of which
    narrows from ``shared_hidden_size`` to ``shared_hidden_size // 2``. The
    content, execution and style heads are each a single linear layer mapping
    the trunk features to ``output_embedding_size``.
    """

    def __init__(self, input_size, shared_hidden_size, output_embedding_size, dropout_rate = 0.3):
        super().__init__()
        trunk_out_size = shared_hidden_size // 2

        # Layers are instantiated in the order trunk, content, execution, style.
        trunk_layers = [
            nn.Linear(input_size, shared_hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(shared_hidden_size, trunk_out_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        ]
        self.shared_trunk = nn.Sequential(*trunk_layers)

        self.content_head = nn.Linear(trunk_out_size, output_embedding_size)
        self.execution_head = nn.Linear(trunk_out_size, output_embedding_size)
        self.style_head = nn.Linear(trunk_out_size, output_embedding_size)

    def forward(self, x):
        """Return ``(content_out, exec_out, style_out)`` computed from ``x``."""
        trunk_features = self.shared_trunk(x)
        return (
            self.content_head(trunk_features),
            self.execution_head(trunk_features),
            self.style_head(trunk_features),
        )

In [13]:
# Hyper-parameters for the model and optimiser.
INPUT_EMBEDDING_SIZE = 768
OUTPUT_EMBEDDING_SIZE = 768
SHARED_HIDDEN_SIZE = 1024
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
SEED = 42  # same value as the random_state passed to train_test_split

# Seed torch's global RNG before the model is built so that weight
# initialisation, dropout masks and the shuffled training order are
# repeatable when the notebook is re-run from a fresh kernel.
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = PromptDecomposerMLP(INPUT_EMBEDDING_SIZE, SHARED_HIDDEN_SIZE, OUTPUT_EMBEDDING_SIZE).to(device)
loss_fn = nn.CosineEmbeddingLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr = LEARNING_RATE)