In [15]:
import torch
import os

os.chdir("c:/Users/cunn2/OneDrive/DSML/Project/thesis-repo")

from sms.exp1.config_classes import load_config_from_launchplan
from sms.exp1.run_training import build_encoder, build_projector
from sms.exp1.models.siamese import SiameseModel

config = load_config_from_launchplan("sms/exp1/runs/run_20240926_162652/original_launchplan.yaml")

encoder = build_encoder(config.model_dump())
projector = build_projector(config.model_dump())

model = SiameseModel(encoder, projector)

print(encoder)
print(projector)
print(model)


PianoRollConvEncoder(
  (conv_layers): Sequential(
    (0): Conv2d(1, 2, kernel_size=(10, 10), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(2, 4, kernel_size=(6, 6), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (fc): Linear(in_features=20768, out_features=64, bias=True)
)
ProjectionHead(
  (projector): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=64, bias=True)
    (5): ReLU()
  )
)
SiameseModel(
  (encoder): PianoRollConvEncoder(
    (conv_layers): S

In [16]:
pt_encoder = build_encoder(config.model_dump())
pt_encoder.load_state_dict(torch.load("sms/exp1/runs/run_20240926_162652/pretrain_saved_model.pth"))    

ft_encoder = build_encoder(config.model_dump())
ft_encoder.load_state_dict(torch.load("sms/exp1/runs/run_20240926_162652/finetune_saved_model.pth"))


  pt_encoder.load_state_dict(torch.load("sms/exp1/runs/run_20240926_162652/pretrain_saved_model.pth"))
  ft_encoder.load_state_dict(torch.load("sms/exp1/runs/run_20240926_162652/finetune_saved_model.pth"))


<All keys matched successfully>

In [17]:
data = torch.load(r"C:\Users\cunn2\OneDrive\DSML\Project\thesis-repo\data\exp1\train_data.pt")

  data = torch.load(r"C:\Users\cunn2\OneDrive\DSML\Project\thesis-repo\data\exp1\train_data.pt")


In [18]:
data[0]

array([[ 0.2, 67. ],
       [ 1. , 74. ],
       [ 2. , 76. ],
       [ 0.8, 74. ]])

In [19]:
from sms.src.synthetic_data.formatter import InputFormatter
from sms.src.synthetic_data.note_arr_mod import NoteArrayModifier
import numpy as np
import logging
from sms.src.log import configure_logging

logger = logging.getLogger(__name__)
configure_logging(console_level=logging.DEBUG)

formatter = InputFormatter(**config.model_dump()['input'])

aug_dict = {
    "use_transposition": True,
    "use_shift_selected_notes_pitch": False,
    "use_change_note_durations": False,
    "use_delete_notes": False,
    "use_insert_notes": False
}

modifier = NoteArrayModifier()

def format_data(data: np.ndarray):
    return formatter(data).astype(np.float32).copy()

anchor = data[0]
pos = modifier(anchor, aug_dict)
neg = data[18]
print(anchor)
print(pos)
print(neg)

anchor = format_data(anchor)
pos = format_data(pos)
neg = format_data(data[17])

anchor_enc = ft_encoder((torch.from_numpy(anchor)).unsqueeze(0))[0].detach().numpy()   
pos_enc = ft_encoder((torch.from_numpy(pos)).unsqueeze(0))[0].detach().numpy()
neg_enc = ft_encoder((torch.from_numpy(neg)).unsqueeze(0))[0].detach().numpy()

print(f'pos distance: {np.linalg.norm(anchor_enc - pos_enc)}')
print(f'neg distance: {np.linalg.norm(anchor_enc - neg_enc)}')


[2024-09-29 15:12:42] [DEBUG] Transposing non-rest notes by 5 semitones.


[[ 0.2 67. ]
 [ 1.  74. ]
 [ 2.  76. ]
 [ 0.8 74. ]]
tensor([[ 0.2000, 72.0000],
        [ 1.0000, 79.0000],
        [ 2.0000, 81.0000],
        [ 0.8000, 79.0000]], dtype=torch.float64)
[[ 0.75 67.  ]
 [ 0.25 69.  ]
 [ 1.   71.  ]
 [ 0.5  67.  ]
 [ 0.5  67.  ]
 [ 1.   69.  ]]
pos distance: 22.374937057495117
neg distance: 20.81243324279785


In [20]:
from typing import List

def format_data_for_conv_enc(data: np.ndarray, formatter: InputFormatter):
    return torch.from_numpy(formatter(data).astype(np.float32).copy())

def format_dataset_for_conv_enc(dataset: List[np.ndarray]):
    formatted_data = [format_data_for_conv_enc(data, formatter) for data in dataset]
    return torch.stack(formatted_data, dim=0)

data_formatted = format_dataset_for_conv_enc(data)


In [21]:
print(data_formatted.shape)
embeddings = ft_encoder(data_formatted)


torch.Size([14631, 128, 32])


In [22]:
embeddings = embeddings.detach()

In [23]:
embeddings.shape

torch.Size([14631, 64])

In [24]:
embeddings[0]

tensor([ 9.8189e+00, -1.0568e+01,  1.3061e+00, -3.7155e+00,  2.3748e+00,
        -9.1931e+00,  4.0135e+00,  7.2098e+00,  1.6388e+00,  1.6952e+01,
         3.5173e+00, -3.6373e+00,  3.9732e+00,  1.5120e+01,  1.0261e+00,
         1.1293e+01,  6.1510e+00, -4.7417e+00, -4.8012e+00, -1.4328e+01,
        -2.8075e+00, -3.0592e+00, -7.0917e+00, -7.3687e+00,  1.3666e+00,
        -2.9715e+00,  4.9831e+00,  5.1463e+00, -3.3394e+00, -6.1344e+00,
        -6.1643e+00, -1.6928e+01, -2.0705e+00,  7.5320e-01, -7.7485e+00,
        -9.8449e+00,  4.6780e+00,  3.3799e+00,  4.8790e+00, -1.0553e-02,
         2.3219e+00,  1.0051e+01, -8.1662e+00,  1.1222e+01,  1.2169e+00,
         2.3066e+00, -1.3469e+01,  4.2552e-01, -3.1998e+00, -1.3172e+00,
        -1.8167e+00,  6.6501e+00,  8.4608e-01,  1.6913e+00, -3.8155e+00,
        -4.9456e+00, -9.2409e+00,  9.4063e+00,  9.8553e+00,  3.5024e+00,
        -1.2613e+01, -2.1675e+01, -5.5904e+00, -1.0875e+00])

## evaluation

In [34]:
import faiss
import numpy as np
from typing import Dict, Any

class CustomFAISSIndex:
    def __init__(self, index_type: str, index_args: List[Any] = [], index_kwargs: Dict[str, Any] = {}):
        self.index = getattr(faiss, index_type)(*index_args, **index_kwargs)
        self.id_to_index = {}  # Maps custom IDs to FAISS indices
        self.index_to_id = {}  # Maps FAISS indices to custom IDs
        self.id_to_data = {}   # Maps custom IDs to original data

    def add_with_id(self, id, vector, original_data=None):
        if id in self.id_to_index:
            raise ValueError(f"ID {id} already exists in the index")
        
        index = self.index.ntotal
        self.index.add(np.array([vector], dtype=np.float32))
        self.id_to_index[id] = index
        self.index_to_id[index] = id
        if original_data is not None:
            self.id_to_data[id] = original_data

    def remove(self, id):
        if id not in self.id_to_index:
            raise ValueError(f"ID {id} not found in the index")
        
        index_to_remove = self.id_to_index[id]
        self.index.remove_ids(np.array([index_to_remove]))
        
        # Update mappings
        del self.index_to_id[index_to_remove]
        del self.id_to_index[id]
        if id in self.id_to_data:
            del self.id_to_data[id]
        
        # # Update remaining indices
        # for i in range(index_to_remove, self.index.ntotal):
        #     old_id = self.index_to_id[i + 1]
        #     self.index_to_id[i] = old_id
        #     self.id_to_index[old_id] = i
        # del self.index_to_id[self.index.ntotal]

        # Update remaining indices
        for i in range(index_to_remove, self.index.ntotal):
            if i + 1 in self.index_to_id:
                old_id = self.index_to_id[i + 1]
                self.index_to_id[i] = old_id
                self.id_to_index[old_id] = i
        
        # Remove the last entry if it exists
        if self.index.ntotal in self.index_to_id:
            del self.index_to_id[self.index.ntotal]

    def search(self, query_vector, k,):
        distances, indices = self.index.search(np.array([query_vector], dtype=np.float32), k)
        results = []
        for idx in indices[0]:
            if idx != -1 and idx in self.index_to_id:
                id = self.index_to_id[idx]
                results.append((id, self.id_to_data.get(id)))
        return results

    def get_vector(self, id):
        if id not in self.id_to_index:
            raise ValueError(f"ID {id} not found in the index")
        index = self.id_to_index[id]
        return self.index.reconstruct(index)

    def get_original_data(self, id):
        return self.id_to_data.get(id)
    
    def get_all_items(self, limit=3):
        items = []
        for id in list(self.id_to_data.keys())[:limit]:  # Limit the number of items
            vector = self.get_vector(id)
            original_data = self.get_original_data(id)
            items.append((id, vector, original_data))
        return items

    def __repr__(self):
        items = self.get_all_items(limit=3)  # Limit to 3 items
        total_items = self.index.ntotal
        repr_str = f"CustomFAISSIndex with {total_items} items:\n"
        for id, vector, original_data in items:
            repr_str += f"  ID: {id}\n"
            repr_str += f"    Vector: {vector}\n"
            repr_str += f"    Original Data: {original_data}\n"
        if total_items > 3:
            repr_str += f"  ... and {total_items - 3} more items\n"
        return repr_str



In [35]:
# Test: Add 5 items, add a sixth, then remove the sixth

import numpy as np
from uuid import uuid4

# Create a test index
test_index = CustomFAISSIndex(index_type="IndexLSH", index_args=[64, 256])

# Generate 5 random items
for _ in range(5):
    item_id = str(uuid4())
    vector = np.random.rand(64).astype(np.float32)
    data = f"Test data for {item_id}"
    test_index.add_with_id(item_id, vector, data)

# Verify 5 items are in the index
assert test_index.index.ntotal == 5, f"Expected 5 items, but found {test_index.index.ntotal}"

# Add a sixth item
sixth_id = str(uuid4())
sixth_vector = np.random.rand(64).astype(np.float32)
sixth_data = "Sixth item data"
test_index.add_with_id(sixth_id, sixth_vector, sixth_data)

# Verify 6 items are in the index
assert test_index.index.ntotal == 6, f"Expected 6 items, but found {test_index.index.ntotal}"

# Remove the sixth item
test_index.remove(sixth_id)

# Verify back to 5 items in the index
assert test_index.index.ntotal == 5, f"Expected 5 items after removal, but found {test_index.index.ntotal}"

# Try to access the removed item (should raise an error)
try:
    test_index.get_vector(sixth_id)
    raise AssertionError("Expected ValueError when accessing removed item")
except ValueError:
    pass

print("All tests passed successfully!")


All tests passed successfully!


In [26]:
from uuid import uuid4

data_ids = [str(uuid4()) for _ in range(len(data))]
data_dict = dict(zip(data_ids, data))
embeddings_dict = dict(zip(data_ids, embeddings.detach().numpy()))

dim = list(embeddings_dict.values())[0].shape[0]
embedding_index = CustomFAISSIndex(index_type="IndexLSH", index_args=[dim, 256])
for key, value in embeddings_dict.items():
    embedding_index.add_with_id(key, value, data_dict[key])

### checks

In [27]:
embedding_index.get_original_data(data_ids[0])


array([[ 0.2, 67. ],
       [ 1. , 74. ],
       [ 2. , 76. ],
       [ 0.8, 74. ]])

In [28]:
# # Check 1: Verify all documents are added
# print("Check 1: Verify all documents are added")
# for doc_id in ["doc1", "doc2", "doc3"]:
#     vector = embedding_index.get_vector(doc_id)
#     data = embedding_index.get_original_data(doc_id)
#     print(f"{doc_id}: Vector = {vector}, Data = {data}")

# # Check 2: Remove a document and verify it's gone
# print("\nCheck 2: Remove a document and verify it's gone")
# embedding_index.remove("doc2")
# try:
#     embedding_index.get_vector("doc2")
# except ValueError as e:
#     print(f"Expected error: {e}")

# # Check 3: Verify remaining documents are still accessible
# print("\nCheck 3: Verify remaining documents are still accessible")
# for doc_id in ["doc1", "doc3"]:
#     vector = embedding_index.get_vector(doc_id)
#     data = embedding_index.get_original_data(doc_id)
#     print(f"{doc_id}: Vector = {vector}, Data = {data}")

# # Check 4: Add a new document and verify it's added correctly
# print("\nCheck 4: Add a new document and verify it's added correctly")
# embedding_index.add_with_id("doc4", np.array([4] * dim), 4)
# embedding_index.add_with_id("doc5", np.array([5] * dim), 5)
# vector = embedding_index.get_vector("doc4")
# data = embedding_index.get_original_data("doc4")
# print(f"doc4: Vector = {vector}, Data = {data}")

# # Check 5: Perform a search and verify results
# print("\nCheck 5: Perform a search and verify results")
# query_vector = np.array([2.5] * dim)
# results = embedding_index.search(query_vector, k=2)
# print(f"Search results for query {query_vector}:")
# for id, data in results:
#     print(f"ID: {id}, Data: {data}")

# # Check 6: Try to add a document with an existing ID (should raise an error)
# print("\nCheck 6: Try to add a document with an existing ID")
# try:
#     embedding_index.add_with_id("doc1", np.array([5] * dim), 5)
# except ValueError as e:
#     print(f"Expected error: {e}")

# # Check 7: Try to remove a non-existent document (should raise an error)
# print("\nCheck 7: Try to remove a non-existent document")
# try:
#     embedding_index.remove("doc5")
# except ValueError as e:
#     print(f"Expected error: {e}")

# exp1 eval loop


In [29]:
# produce vector embeddings
from uuid import uuid4
import torch
import logging
import numpy as np
from typing import Callable, Optional, List, Dict
from sms.src.synthetic_data.formatter import InputFormatter
from sms.src.synthetic_data.note_arr_mod import NoteArrayModifier

from sms.exp1.run_training import build_encoder, build_projector
from sms.exp1.models.siamese import SiameseModel

logger = logging.getLogger(__name__)

def augment_chunk(chunk: np.ndarray, augmentation: str):
    """ 
    augmentation is one of the following:
        use_transposition
        use_shift_selected_notes_pitch
        use_change_note_durations
        use_delete_notes
        use_insert_notes
    """
    aug_dict = {
        "use_transposition": False,
        "use_shift_selected_notes_pitch": False,
        "use_change_note_durations": False,
        "use_delete_notes": False,
        "use_insert_notes": False
    }
    aug_dict[augmentation] = True
    modifier = NoteArrayModifier()
    return modifier(chunk, aug_dict)

def create_augmented_data(data_dict: Dict[str, np.ndarray], anchor_keys: List[str]) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Create the augmented data for the given anchor keys.
    Returns a dictionary of dictionaries, where the outer dictionary is keyed by the anchor keys, and the inner dictionary 
        is keyed by the type of augmentation and contains the augmented data.
    """
    augmented_data = {}
    for key in anchor_keys:
        chunk = data_dict[key]
        augmented_data[key] = {
            "chunk_transposed": augment_chunk(chunk, "use_transposition"),
            "chunk_one_pitch_shifted": augment_chunk(chunk, "use_shift_selected_notes_pitch"),
            "chunk_note_duration_changed": augment_chunk(chunk, "use_change_note_durations"),
            "chunk_note_deleted": augment_chunk(chunk, "use_delete_notes"),
            "chunk_note_inserted": augment_chunk(chunk, "use_insert_notes")
        }
    return augmented_data

def build_model(dumped_lp_config: Dict[str, Any], full_model_path: Optional[str] = None, encoder_path: Optional[str] = None, use_full_model: bool = False):
    """
    Only one of full_model_path or encoder_path should be provided. If both are provided, full_model_path is used.
    """
    encoder = build_encoder(dumped_lp_config)
    projector = build_projector(dumped_lp_config)
    model = SiameseModel(encoder, projector)
    if full_model_path is not None:
        model.load_state_dict(torch.load(full_model_path))
    elif encoder_path is not None:
        model = model.get_encoder()
        model.load_state_dict(torch.load(encoder_path))
    else:
        raise ValueError("Either full_model_path or encoder_path must be provided.")
    if not use_full_model and full_model_path is not None:
        model = model.get_encoder()
    return model

def create_embedding_dict(data_dict: Dict[str, np.ndarray], dumped_lp_config: Dict[str, Any], model: Callable) -> Dict[str, np.ndarray]:
    """
    Create the embedding dictionary for the given model. The dumped_lp_config is used to determine the input format of the model.
    Returns the data_dict, but with embeddings instead of the original data.
    """
    formatter = InputFormatter(**dumped_lp_config['input'])
    formatted_data_list = [torch.from_numpy(formatter(chunk).astype(np.float32).copy()) for chunk in data_dict.values()]
    formatted_data_stacked = torch.stack(formatted_data_list, dim=0) # shape [num_chunks, *input_shape]
    embeddings_stacked = model(formatted_data_stacked)
    embeddings_dict = {key: embeddings_stacked[i].detach().numpy() for i, key in enumerate(data_dict.keys())}
    return embeddings_dict

def embeddings_to_faiss_index(
        embeddings_dict: Dict[str, np.ndarray], 
        index_type: str, 
        index_args: List[Any] = [], 
        index_kwargs: Dict[str, Any] = {}
    ) -> CustomFAISSIndex:

    embedding_index = CustomFAISSIndex(index_type=index_type, index_args=index_args, index_kwargs=index_kwargs)
    for key, value in embeddings_dict.items():
        embedding_index.add_with_id(key, value)
    return embedding_index

    # For each embedding collection in embeddings_dicts, we perform the augmentation evaluation experiment num_loops times.
    # An augmentation evaluation experiment involves the following steps:
    # - Randomly select an anchor from data_dict
    # - Remove the anchor from data_dict
    # - Apply each of the five given augmentations to the anchor
    # - For each of the augmented melodies, add it to the database and perform a nearest neighbor search on the FAISS index
    # - Calculate the precision and recall of the search for each k in k_list

def evaluate_top_k(
        embedding_dict: Dict[str, Dict[str, np.ndarray]],
        augmented_embedding_dict: Dict[str, Dict[str, np.ndarray]], 
        k_list: List[int], 
        index: CustomFAISSIndex
    ) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
    """
    index is a CustomFAISSIndex object which has been initialized with the embeddings_dict.
    For each of the keys in augment_dict, we perform the following steps:
    - Remove the anchor (embedding_dict[key]) from the index
    - Add one of the augmentations from that key to the index
    - Perform a nearest neighbor search on the index using the anchor and record the position of the augmentation
    - Repeat for each augmentation
    
    Then we report the average precision and recall for each k in k_list.
    
    Args:
        embeddings_dict: dictionary of embeddings, keyed by data ids
        augmented_embedding_dict: dictionary keyed by a subset of the ids in embeddings_dict, containing dictionaries of augmented data
        k_list: list of k values to evaluate
        index: CustomFAISSIndex object which has been initialized with the embeddings_dict

    Returns:
        results: dictionary of precision and recall for each augmentation and k in k_list
    """
    results = {aug_type: {k: {'precision': [], 'recall': []} for k in k_list} for aug_type in augmented_embedding_dict[list(augmented_embedding_dict.keys())[0]].keys()}
    
    for anchor_id, augmentations in augmented_embedding_dict.items():
        anchor_embedding = embedding_dict[anchor_id]
        
        # remove anchor from index
        index.remove(anchor_id)
        
        for aug_type, augmented_data in augmentations.items():
            # add augmented data to index
            aug_id = f"{anchor_id}_aug_{aug_type}"
            index.add_with_id(aug_id, augmented_data)
            
            # perform search
            search_results = index.search(anchor_embedding, max(k_list))
            
            # calculate precision and recall for each k
            for k in k_list:
                top_k_results = search_results[:k]
                true_positives = sum(1 for id, _ in top_k_results if id == aug_id)
                
                precision = true_positives / k
                recall = 1 if true_positives > 0 else 0  # Recall is 1 if found, 0 if not
                
                results[aug_type][k]['precision'].append(precision)
                results[aug_type][k]['recall'].append(recall)
            
            # remove augmented data from index
            index.remove(aug_id)
        
        # add anchor back to index
        index.add_with_id(anchor_id, anchor_embedding)
    
    # Calculate average precision and recall
    for aug_type in results:
        for k in k_list:
            results[aug_type][k]['avg_precision'] = np.mean(results[aug_type][k]['precision'])
            results[aug_type][k]['avg_recall'] = np.mean(results[aug_type][k]['recall'])
    
    return results

In [30]:
# data_ids = [str(uuid4()) for _ in range(len(data))]
# data_dict = dict(zip(data_ids, data))
# filtered_data = [arr for arr in data if arr.shape[0] > 2]
# filtered_data_ids = [str(uuid4()) for _ in range(len(filtered_data))]
# filtered_data_dict = dict(zip(filtered_data_ids, filtered_data))

# num_loops = 100
# anchor_keys = np.random.choice(filtered_data_ids, size=num_loops, replace=False)
# augmented_data = create_augmented_data(filtered_data_dict, anchor_keys)

# cfg = load_config_from_launchplan("sms/exp1/runs/run_20240926_162652/original_launchplan.yaml")
# encoder = build_encoder(cfg.model_dump())
# encoder.load_state_dict(torch.load("sms/exp1/runs/run_20240926_162652/pretrain_saved_model.pth"))

# create_embedding_dict(augmented_data, cfg.model_dump(), encoder)


[2024-09-29 15:12:51] [DEBUG] Transposing non-rest notes by -6 semitones.
[2024-09-29 15:12:51] [DEBUG] Shifting note at index 1 by -3 semitones.
[2024-09-29 15:12:51] [DEBUG] Scaling duration of note at index 0 by a factor of 1.5.
[2024-09-29 15:12:51] [DEBUG] Truncated note 3 by 0.2999999999999998 to maintain total duration.
[2024-09-29 15:12:51] [DEBUG] Deleting notes at indices [3].
[2024-09-29 15:12:51] [DEBUG] Elongated last note by 1.4 to maintain total duration.
[2024-09-29 15:12:51] [DEBUG] Inserting note at index 2 with duration 0.5 and relative pitch -4.
[2024-09-29 15:12:51] [DEBUG] Truncated note 4 by 0.5 to maintain total duration.
[2024-09-29 15:12:51] [DEBUG] Transposing non-rest notes by 14 semitones.
[2024-09-29 15:12:51] [DEBUG] Shifting note at index 6 by 1 semitones.
[2024-09-29 15:12:51] [DEBUG] Scaling duration of note at index 1 by a factor of 3.0.
[2024-09-29 15:12:51] [DEBUG] Removed note 7 with duration 0.65 to adjust total duration.
[2024-09-29 15:12:51] [DE

IndexError: too many indices for array: array is 0-dimensional, but 2 were indexed

In [None]:
flattened_augmented_data = []
for key1, dict1 in augmented_data.items():
    for key2, arr in dict1.items():
        flattened_augmented_data.append(arr)


In [None]:
flattened_augmented_data

[tensor([[ 0.6000, 88.0000],
         [ 0.5000, 88.0000],
         [ 0.5000, 87.0000],
         [ 0.5000, 85.0000],
         [ 0.5000, 84.0000],
         [ 1.0000, 85.0000],
         [ 0.4000, 85.0000]], dtype=torch.float64),
 tensor([[ 0.6000, 75.0000],
         [ 0.5000, 75.0000],
         [ 0.5000, 76.0000],
         [ 0.5000, 72.0000],
         [ 0.5000, 71.0000],
         [ 1.0000, 72.0000],
         [ 0.4000, 72.0000]], dtype=torch.float64),
 tensor([[ 0.6000, 75.0000],
         [ 0.5000, 75.0000],
         [ 0.5000, 74.0000],
         [ 0.2500, 72.0000],
         [ 0.5000, 71.0000],
         [ 1.0000, 72.0000],
         [ 0.6500, 72.0000]], dtype=torch.float64),
 tensor([[ 0.6000, 75.0000],
         [ 0.5000, 75.0000],
         [ 0.5000, 72.0000],
         [ 0.5000, 71.0000],
         [ 1.0000, 72.0000],
         [ 0.9000, 72.0000]], dtype=torch.float64),
 tensor([[ 0.5000, 70.0000],
         [ 0.6000, 75.0000],
         [ 0.5000, 75.0000],
         [ 0.5000, 74.0000],
         

In [None]:
cfg = load_config_from_launchplan("sms/exp1/runs/run_20240926_162652/original_launchplan.yaml")
encoder = build_encoder(cfg.model_dump())
encoder.load_state_dict(torch.load("sms/exp1/runs/run_20240926_162652/pretrain_saved_model.pth"))

embeddings_dict = create_embedding_dicts(data_dict, cfg.model_dump(), encoder)



  encoder.load_state_dict(torch.load("sms/exp1/runs/run_20240926_162652/pretrain_saved_model.pth"))


https://python.langchain.com/docs/integrations/vectorstores/faiss/#similarity-search-with-filtering

In [31]:
from pydantic import BaseModel
from sms.exp1.config_classes import LaunchPlanConfig

class ModelEvalConfig(BaseModel):
    name: str
    lp_config: LaunchPlanConfig
    mod_path: str
    path_type: str    #'full' or 'encoder'
    use_full_model: bool

def run_evaluation(
    data_dict: Dict[str, np.ndarray],
    num_loops: int,
    model_configs: List[ModelEvalConfig]
    ) -> Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]:

    # generate random augmentations
    anchor_keys = np.random.choice(list(data_dict.keys()), size=num_loops, replace=False)
    augmented_data = create_augmented_data(data_dict, anchor_keys)

    results = {}
    for eval_config in model_configs:
        logger.info(f"Running evaluation for {eval_config.name}")

        dumped_lp_config = eval_config.lp_config.model_dump()
        bm_cfg = {'full_model_path': eval_config.mod_path} if eval_config.path_type == 'full' else {'encoder_path': eval_config.mod_path}

        model = build_model(dumped_lp_config, **bm_cfg, use_full_model=eval_config.use_full_model)
        embeddings_dict = create_embedding_dict(data_dict, dumped_lp_config, model)

        # create augmented embeddings structure
        augmented_embeddings_dict = {}
        for data_id, aug_dict in augmented_data.items():
            augmented_embeddings_dict[data_id] = create_embedding_dict(aug_dict, dumped_lp_config, model)

        dim = list(embeddings_dict.values())[0].shape[0]
        index = embeddings_to_faiss_index(embeddings_dict=embeddings_dict, index_type="IndexFlatL2", index_args=[dim])

        results[eval_config.name] = evaluate_top_k(embeddings_dict, augmented_embeddings_dict, [1, 3, 5, 10, 25, 50, 100], index)
    
    return results



In [None]:
filtered_data_ids = [str(uuid4()) for _ in range(len(filtered_data))]
filtered_data_dict = dict(zip(filtered_data_ids, filtered_data))


In [36]:
conv_eval_cfg = ModelEvalConfig(
    name="conv_encoder",
    lp_config=load_config_from_launchplan("sms/exp1/runs/run_20240926_162652/original_launchplan.yaml"),
    mod_path="sms/exp1/runs/run_20240926_162652/pretrain_saved_model.pth",
    path_type='encoder',
    use_full_model=False
)

configure_logging(console_level=logging.INFO)
results = run_evaluation(filtered_data_dict, 100, [conv_eval_cfg])

[2024-09-29 15:21:04] [DEBUG] Transposing non-rest notes by 12 semitones.
[2024-09-29 15:21:04] [DEBUG] Shifting note at index 1 by -5 semitones.
[2024-09-29 15:21:04] [DEBUG] Scaling duration of note at index 2 by a factor of 0.5.
[2024-09-29 15:21:04] [DEBUG] Elongated last note by 0.5 to maintain total duration.
[2024-09-29 15:21:04] [DEBUG] Deleting notes at indices [3].
[2024-09-29 15:21:04] [DEBUG] Elongated last note by 1.5 to maintain total duration.
[2024-09-29 15:21:04] [DEBUG] Inserting note at index 0 with duration 0.25 and relative pitch -3.
[2024-09-29 15:21:04] [DEBUG] Truncated note 5 by 0.25 to maintain total duration.
[2024-09-29 15:21:04] [DEBUG] Transposing non-rest notes by -5 semitones.
[2024-09-29 15:21:04] [DEBUG] Shifting note at index 3 by 4 semitones.
[2024-09-29 15:21:04] [DEBUG] Scaling duration of note at index 6 by a factor of 2.0.
[2024-09-29 15:21:04] [DEBUG] Truncated note 6 by 0.20000000000000018 to maintain total duration.
[2024-09-29 15:21:04] [DEBU

In [38]:
def print_nested_dict(d, indent=0):
    for key, value in d.items():
        print('  ' * indent + str(key) + ':', end='')
        if isinstance(value, dict):
            print()
            print_nested_dict(value, indent+1)
        else:
            if isinstance(value, float):
                print(f" {value:.4f}")
            else:
                print(f" {value}")

# Usage
print_nested_dict(results)

conv_encoder:
  chunk_transposed:
    1:
      precision: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
      recall: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
      avg_precision: 0.0000
      avg_recall: 0.0000
    3:
      precision: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0