In [None]:
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import torch
import polars as pl
import numpy as np
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths
from huggingface_hub import notebook_login
import tqdm
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

In [None]:
# The checkpoint used below is gated on the Hugging Face Hub, so authenticate
# interactively before any download is attempted.
notebook_login()

In [None]:
# Hub identifier of the (gated) pretrained checkpoint shared by tokenizer and model.
checkpoint = "GerMedBERT/medbert-512"

# Tokenizer that matches the checkpoint's vocabulary.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Bare encoder (no task head) used to produce the embeddings;
# `model_base` is the same network exposed through `.base_model`.
model = AutoModel.from_pretrained(checkpoint)
model_base = model.base_model

In [None]:
# Load the preprocessed line-labelling data and drop rows without text.
df = pl.read_csv(os.path.join(paths.DATA_PATH_PREPROCESSED, "line_labelling_clean.csv"))
df = df.filter(pl.col("text").is_not_null())

# Tokenize every line in one call. padding=True pads all rows to the longest
# sequence in the whole dataset (capped at 512), so every batch below has the
# same sequence length and the batches can be concatenated along dim 0 later.
tokenized = tokenizer(df["text"].to_list(), padding=True, truncation=True, max_length=512, return_tensors="pt")

# Embed data in batches while logging progress.
batch_size = 32
embeddings = []  # one last_hidden_state tensor per batch

# from_pretrained already returns the model in eval mode; set it explicitly so
# dropout stays off even if another cell switched the model to training mode.
model.eval()
# no_grad: do not build an autograd graph at all (the previous `.detach()` only
# dropped the graph after it had already been allocated for every batch).
with torch.no_grad():
    for start in tqdm.tqdm(range(0, len(tokenized["input_ids"]), batch_size)):
        stop = start + batch_size
        batch_output = model(
            input_ids=tokenized["input_ids"][start:stop],
            attention_mask=tokenized["attention_mask"][start:stop],
        )
        embeddings.append(batch_output.last_hidden_state)
    

In [None]:
# Stack the per-batch outputs into one tensor with one row per text line.
test = torch.cat(embeddings, dim=0)
# Every non-null text line must have produced exactly one embedding row.
assert test.shape[0] == len(df), (test.shape, len(df))

# Sentence-level representation: the hidden state of the first (CLS) token.
cls = test[:, 0, :]
cls.shape

In [None]:
# Visualize the CLS embeddings in 2D with PCA, one color per aggregated class.
# random_state pins the randomized SVD solver that PCA may choose for large inputs.
pca = PCA(n_components=2, random_state=0)
components = pca.fit_transform(cls.detach().cpu().numpy())

x = components[:, 0]
y = components[:, 1]
# One label per embedded row. The previous `[:20]` slice did not match len(x),
# which made plt.scatter fail for any dataset that is not exactly 20 rows.
labels = df["class_agg"].to_list()
assert len(labels) == len(x), (len(labels), len(x))

# Sorted class list -> stable label-to-color mapping across runs
# (Series.unique() does not preserve order by default).
class_names = df["class_agg"].unique().sort().to_list()
label_to_numeric = {label: i for i, label in enumerate(class_names)}

# Convert labels to numeric color codes
numeric_labels = [label_to_numeric[label] for label in labels]

# Discrete colormap with exactly one color per class.
cmap = plt.get_cmap("viridis", len(class_names))

fig, ax = plt.subplots()
scatter = ax.scatter(
    x,
    y,
    c=numeric_labels,
    cmap=cmap,
    # Center each class's color band on its integer tick.
    vmin=-0.5,
    vmax=len(class_names) - 0.5,
)

# Colorbar ticks show the class names rather than their numeric codes.
cbar = fig.colorbar(scatter, ax=ax, ticks=list(range(len(class_names))))
cbar.ax.set_yticklabels([str(name) for name in class_names])
cbar.set_label("Class Label")

ax.set_xlabel("Component 1")
ax.set_ylabel("Component 2")
ax.set_title("PCA of CLS embeddings by class")
plt.show()

In [None]:
# Number of projected points (first PCA component).
np.shape(x)

In [None]:
# Inspect the fields returned by one forward pass that also returns every layer's
# hidden states. `output` was previously never defined in this notebook, so this
# cell (and the comparison cell further down) raised NameError on a fresh kernel.
sample_inputs = {
    "input_ids": tokenized["input_ids"][:2],
    "attention_mask": tokenized["attention_mask"][:2],
}
with torch.no_grad():
    output = model(**sample_inputs, output_hidden_states=True)
#torch.equal(output["last_hidden_state"], output["hidden_states"][-1])
output.keys()

In [None]:
# Second forward pass on the same two-line sample, this time through `model_base`,
# so its hidden states can be compared with `output` further down.
# NOTE(review): `output2` was previously never defined (NameError on a fresh kernel);
# using model_base as the second model is an assumption about the original intent — confirm.
sample_inputs = {
    "input_ids": tokenized["input_ids"][:2],
    "attention_mask": tokenized["attention_mask"][:2],
}
with torch.no_grad():
    output2 = model_base(**sample_inputs, output_hidden_states=True)
output2.keys()

In [None]:
# Check whether hidden_states[1] is identical between the two forward passes.
# Index 0 is presumably the embedding-layer output in Hugging Face models, so
# index 1 would be the first encoder layer — confirm.
# NOTE(review): requires `output` and `output2`, both created with
# output_hidden_states=True, from the cells above.
torch.equal(output.hidden_states[1], output2.hidden_states[1])

In [None]:
# Print every parameter of the base model next to its tensor shape.
for entry in ((param_name, weight.shape) for param_name, weight in model_base.named_parameters()):
    print(*entry)