In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd

# Filepath to the embeddings CSV; it is loaded with pd.read_csv in a later cell.
# NOTE(review): this is an absolute, machine-specific path — confirm or adjust
# it before running the notebook on another machine.
fname = "/mnt/mimic/data/HAIM/mimic_extras/embeddings.csv"

In [2]:
from dataclasses import dataclass, field
from typing import Optional


# Additional third-party packages (datasets, peft) that must be installed separately
from datasets import load_dataset
from functools import partial
from peft import LoraConfig, TaskType, get_peft_model, get_peft_config

In [3]:
# Quantization settings: load weights in 4-bit NF4 and compute in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Tokenizer for the checkpoint; the EOS token id is reused as the padding id.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
tokenizer.pad_token_id = tokenizer.eos_token_id

# Causal LM loaded with the quantization settings above, automatic device
# placement, and the "sdpa" attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    quantization_config=quantization_config,
    attn_implementation="sdpa",
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# LoRA adapter settings for causal language modelling: rank 8, alpha 16,
# dropout 0.1, no bias terms, applied to the projection modules named below.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

In [5]:
# Wrap the loaded model using the settings in lora_config.
# `model` is rebound to the wrapped object, so re-running this cell would pass
# the already-wrapped model back into get_peft_model.
model = get_peft_model(model, lora_config)

# Report the trainable parameter count against the total parameter count.
model.print_trainable_parameters()

trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493


In [None]:

class MLPModel(nn.Module):
    """Three-layer perceptron: input_size -> 512 -> 256 -> output_size.

    ReLU is applied after the first and second linear layers; the output of
    the last linear layer is returned without an activation.

    Args:
        input_size: Number of features in each input vector.
        output_size: Number of values produced per input vector.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Layers are declared in the order they are applied in forward().
        self.fc1 = nn.Linear(input_size, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_size)

    def forward(self, x):
        """Run ``x`` through the three linear layers and return the result."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# Load the embeddings table and derive the binary label column `y`:
#   y = 1 -> death_status == 1 with img_length_of_stay < 48
#   y = 0 -> img_length_of_stay >= 48 (death_status 0 or 1)
# NOTE(review): img_length_of_stay is presumably measured in hours — confirm.
df = pd.read_csv(fname)

# Row masks shared by the three cohorts below. Both comparisons are spelled
# out (rather than negating one) so rows with a missing length of stay match
# neither mask.
short_stay = df['img_length_of_stay'] < 48
long_stay = df['img_length_of_stay'] >= 48
died = df['death_status'] == 1
survived = df['death_status'] == 0

# .assign() returns a new frame, so the label is never written into a
# boolean-indexed slice of `df` (which would raise SettingWithCopyWarning).
df_death_small48 = df[short_stay & died].assign(y=1)
df_alive_big48 = df[long_stay & survived].assign(y=0)
df_death_big48 = df[long_stay & died].assign(y=0)

# Rows with a stay under 48 and death_status == 0 match no cohort and are
# therefore not part of the combined frame.
df = pd.concat([df_death_small48, df_alive_big48, df_death_big48], axis=0)

In [None]:
# Reduce the table to: identifier, every column whose name starts with
# "vd_", and the label — in that column order.
haim_col = df[['haim_id']]
y_col = df[['y']]
vd_cols = df.filter(regex='^vd_')

df = pd.concat([haim_col, vd_cols, y_col], axis=1)
print(df.head())

In [None]:
# Feature width the projection network in a later cell is built for
# (MLPModel(input_size=1024)).
EMBEDDING_DIM = 1024

# Select the feature columns by name instead of by position: a positional
# slice such as iloc[:, 1:1025] would silently include the label column `y`
# when fewer than EMBEDDING_DIM "vd_" columns exist, and silently truncate
# when there are more.
embedding_frame = df.filter(regex='^vd_')
if embedding_frame.shape[1] != EMBEDDING_DIM:
    raise ValueError(
        f"expected {EMBEDDING_DIM} 'vd_' feature columns, "
        f"found {embedding_frame.shape[1]}"
    )

input_embeddings = torch.tensor(embedding_frame.to_numpy(), dtype=torch.float32)

# Show the overall shape and a single row rather than the whole tensor.
print(input_embeddings.shape)
print(input_embeddings[0])

In [None]:
from tqdm import tqdm

# Projection network mapping each 1024-wide embedding to 250 values. It is
# freshly initialised here; no weights are loaded or trained in this cell.
projection_model = MLPModel(input_size=1024, output_size=250).cuda()

# The sigmoid output (between 0 and 1) is multiplied by TOKEN_ID_SPAN and
# shifted by TOKEN_ID_OFFSET, so every value lands between 255649 and 255999.
# NOTE(review): presumably these are ids near the top of the tokenizer
# vocabulary — confirm against the tokenizer in use.
TOKEN_ID_SPAN = 350
TOKEN_ID_OFFSET = 255649

result_embeddings = []

# Inference only: the outputs are rounded (zero gradient), so recording an
# autograd graph for every row would only hold on to GPU memory.
with torch.no_grad():
    for emb in tqdm(input_embeddings, desc="Processing embeddings", unit="embeddings"):
        emb = emb.cuda()
        output_tokens = projection_model(emb)
        normalized_output = torch.sigmoid(output_tokens)
        scaled_output = (normalized_output * TOKEN_ID_SPAN) + TOKEN_ID_OFFSET

        rounded_output = torch.round(scaled_output)
        result_embeddings.append(rounded_output)

In [None]:
# Combine the per-row outputs into a single tensor along a new leading
# dimension (torch.stack uses dim=0 by default), then display it.
transformed_embeddings = torch.stack(result_embeddings)
print(transformed_embeddings)

In [None]:
def formatting_func(example):
    """Build a two-turn prompt string from ``example['data']``.

    The first element of ``example['data']`` is placed after "### USER: " and
    the second after "### ASSISTANT: ", separated by a newline. Any further
    elements are ignored.
    """
    turns = example['data']
    user_turn = turns[0]
    assistant_turn = turns[1]
    return f"### USER: {user_turn}\n### ASSISTANT: {assistant_turn}"

In [4]:
# Print every key of the model's state dict, one per line.
# NOTE(review): this cell's execution count is out of sequence, and the visible
# part of its recorded output lists no adapter entries — presumably it was run
# before the model was wrapped; confirm with a fresh Restart & Run All.
layer_names = model.state_dict().keys()

for name in layer_names:
    print(name)

model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.q_proj.weight.absmax
model.layers.0.self_attn.q_proj.weight.quant_map
model.layers.0.self_attn.q_proj.weight.quant_state.bitsandbytes__nf4
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.k_proj.weight.absmax
model.layers.0.self_attn.k_proj.weight.quant_map
model.layers.0.self_attn.k_proj.weight.quant_state.bitsandbytes__nf4
model.layers.0.self_attn.v_proj.weight
model.layers.0.self_attn.v_proj.weight.absmax
model.layers.0.self_attn.v_proj.weight.quant_map
model.layers.0.self_attn.v_proj.weight.quant_state.bitsandbytes__nf4
model.layers.0.self_attn.o_proj.weight
model.layers.0.self_attn.o_proj.weight.absmax
model.layers.0.self_attn.o_proj.weight.quant_map
model.layers.0.self_attn.o_proj.weight.quant_state.bitsandbytes__nf4
model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.gate_proj.weight.absmax
model.layers.0.mlp.gate_proj.weight.quant_map
model.layers.0.mlp.gate_proj.weight.q