# Embeddings Model

Overview plan:
- Compile (inline) each attribute script together with its local imports, so it has no dependencies on other project modules
- Extract embeddings from all attributes
- Generate synthetic seismic data
- Execute the attributes with the synthetic seismic data
- Prepare the features
- Train the model
- Validate the model

## Compiling the attributes

In [6]:
import ast
import os
import astor


def read_file(file_path, encoding="utf-8"):
    """Return the full text content of ``file_path``.

    The file is decoded as ``encoding`` (UTF-8 by default, the standard
    encoding for Python source) rather than with the platform's locale
    encoding, so results do not change between machines.
    """
    with open(file_path, "r", encoding=encoding) as file:
        return file.read()


def _resolve_local_module(module_name, package_dir, base_path):
    """Return the path of dotted ``module_name`` relative to ``base_path``, or None.

    The name is looked up under ``package_dir``, first as a plain module
    (``a/b.py``) and then as a package (``a/b/__init__.py``).
    """
    stem = os.path.join(package_dir, *module_name.split("."))
    for candidate in (stem + ".py", os.path.join(stem, "__init__.py")):
        relative = os.path.normpath(candidate)
        if os.path.isfile(os.path.join(base_path, relative)):
            return relative
    return None


def _package_dir(file_path, level):
    """Return the directory a relative import with ``level`` leading dots refers to."""
    directory = os.path.dirname(file_path)
    for _ in range(level - 1):
        directory = os.path.dirname(directory)
    return directory


def _inline_file(file_path, visited, base_path, future_imports):
    """Return ``file_path``'s source with local imports inlined (recursive worker).

    ``from __future__`` imports are appended to ``future_imports`` instead of
    being emitted in place, so the caller can put them at the top.
    """
    file_path = os.path.normpath(file_path)
    if file_path in visited:
        return ""
    visited.add(file_path)

    tree = ast.parse(read_file(os.path.join(base_path, file_path)))
    chunks = []

    # Only top-level statements are scanned; imports nested inside functions
    # are emitted unchanged as part of their enclosing statement.
    for node in tree.body:
        if (
            isinstance(node, ast.ImportFrom)
            and node.level == 0
            and node.module == "__future__"
        ):
            future_imports.append(astor.to_source(node))
        elif isinstance(node, ast.Import):
            # ``import a, b`` may mix local and external modules: inline the
            # local ones and keep the rest as a (smaller) import statement.
            external = []
            for alias in node.names:
                local_path = _resolve_local_module(alias.name, "", base_path)
                if local_path is None:
                    external.append(alias)
                else:
                    chunks.append(
                        _inline_file(local_path, visited, base_path, future_imports)
                    )
            if external:
                chunks.append(astor.to_source(ast.Import(names=external)))
        elif isinstance(node, ast.ImportFrom):
            # Relative imports resolve against the importing file's package;
            # absolute ones against ``base_path`` itself.
            package_dir = _package_dir(file_path, node.level) if node.level else ""
            if node.module is not None:
                local_path = _resolve_local_module(node.module, package_dir, base_path)
                if local_path is None:
                    # External (or unresolvable) import: keep it so the combined
                    # script still has it, and failures stay visible.
                    chunks.append(astor.to_source(node))
                else:
                    chunks.append(
                        _inline_file(local_path, visited, base_path, future_imports)
                    )
            else:
                # ``from . import a, b``: each name may itself be a sibling module.
                unresolved = []
                for alias in node.names:
                    local_path = _resolve_local_module(alias.name, package_dir, base_path)
                    if local_path is None:
                        unresolved.append(alias)
                    else:
                        chunks.append(
                            _inline_file(local_path, visited, base_path, future_imports)
                        )
                if unresolved:
                    chunks.append(
                        astor.to_source(
                            ast.ImportFrom(module=None, names=unresolved, level=node.level)
                        )
                    )
        else:
            chunks.append(astor.to_source(node))

    return "".join(chunks)


def inline_imports(file_path, visited=None, base_path=""):
    """Return the source of ``file_path`` with its local imports recursively inlined.

    Args:
        file_path: Script path, relative to ``base_path``.
        visited: Set of already-inlined paths (normalized, relative to
            ``base_path``); each file is inlined at most once. A new set is
            created when omitted.
        base_path: Directory that absolute imports are resolved against.

    Imports that resolve to a ``.py`` file or package ``__init__.py`` under
    ``base_path`` are replaced by that file's (recursively inlined) code. All
    other imports (stdlib, third-party) are kept so the combined script still
    runs.

    NOTE(review): inlining moves a module's names to the top level, so code
    that refers to an inlined module by name (``import m; m.f()``) will not
    work in the combined script.
    """
    if visited is None:
        visited = set()
    future_imports = []
    body = _inline_file(file_path, visited, base_path, future_imports)
    # A ``from __future__`` import is only legal at the top of a module, so the
    # ones found anywhere are deduplicated and placed before all other code.
    return "".join(dict.fromkeys(future_imports)) + body


def create_single_script(main_script_path, base_path=""):
    """Return ``main_script_path`` with its local imports inlined, as one source string.

    ``base_path`` is the directory against which the main script and the
    modules it imports are resolved.
    """
    return inline_imports(main_script_path, base_path=base_path)

# Inline envelope.py and every local module it imports into a single script.
# NOTE(review): base_path is relative to the notebook's working directory.
base_path = "../../../tools/seismic/seismic/attributes"
main_script_path = "envelope.py"
single_script_code = create_single_script(main_script_path, base_path=base_path)
print(single_script_code)

def run(segy_filepath: str, n_workers: int=1, single_threaded: bool=True):
    """Compute the Envelope attribute for the file at ``segy_filepath``.

    NOTE(review): this listing appears to be the captured output of
    ``print(single_script_code)`` above, not hand-written notebook code.

    Builds a ``dasf_seismic`` ``Envelope`` attribute and passes it, together
    with the arguments, to ``seismic.cluster.run_attribute``, returning that
    helper's result. ``n_workers`` / ``single_threaded`` are interpreted by
    ``run_attribute`` (presumably worker count and threading mode — confirm).
    """
    # These imports are inside the function body, so they are not top-level
    # imports and survive the inlining step unchanged.
    from dasf_seismic.attributes.complex_trace import Envelope
    from seismic.cluster import run_attribute
    quality = Envelope()
    return run_attribute(quality, segy_filepath, n_workers, single_threaded)



## Extract code embeddings

In [None]:
from transformers import RobertaTokenizer, RobertaModel
import torch

from functools import lru_cache


@lru_cache(maxsize=4)
def _load_code_model(model_name):
    """Load the tokenizer and encoder for ``model_name`` once and reuse them."""
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    model = RobertaModel.from_pretrained(model_name)
    # Inference only: disable dropout (already the from_pretrained default,
    # made explicit here).
    model.eval()
    return tokenizer, model


def get_code_embedding(code_snippet, model_name="huggingface/CodeBERTa-small-v1", max_length=None):
    """Return the first-token ([CLS]/<s>) embedding of ``code_snippet``.

    Args:
        code_snippet: Source code to embed (a single string).
        model_name: Hugging Face model id. The tokenizer and model are loaded
            once per name and cached.
        max_length: Maximum number of tokens fed to the model. Longer inputs are
            truncated. Defaults to the model's positional limit.

    Returns:
        NumPy array of shape ``(1, hidden_size)``.
    """
    tokenizer, model = _load_code_model(model_name)
    if max_length is None:
        # RoBERTa-style models offset position ids by padding_idx + 1 (= 2), so
        # only max_position_embeddings - 2 tokens fit; longer inputs crash.
        max_length = model.config.max_position_embeddings - 2

    inputs = tokenizer(
        code_snippet, return_tensors="pt", truncation=True, max_length=max_length
    )
    # No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    cls_embedding = outputs.last_hidden_state[:, 0, :]
    return cls_embedding.cpu().numpy()