Fine-tuning of gemma-2b-it

In [1]:
# ALL THE NECESSARY IMPORTS

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm

from dataclasses import dataclass, field
from typing import Optional
from sklearn.model_selection import train_test_split


from datasets import load_dataset
from functools import partial
from peft import LoraConfig, TaskType, get_peft_model, get_peft_config

# Filepath to embeddings
fname = "/mnt/mimic/data/HAIM/mimic_extras/embeddings.csv"

Setting up the model

Different versions, with huggingface LoRA-class or custom Adapter-module.

In [2]:
# LoRA parameter efficient fine-tuning
# Parameters are freezed and small submodules with low-rank matrices ar inserted at the target layers.
# initialization of model
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
tokenizer.pad_token_id = tokenizer.eos_token_id
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto", quantization_config=quantization_config,attn_implementation="sdpa")
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    lora_alpha=16,
    lora_dropout=0.1
)

gemma = get_peft_model(gemma, lora_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Model-structure and trainable parameters (this can be tuned by hyperparameters)
gemma.print_trainable_parameters()
gemma

trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GemmaForCausalLM(
      (model): GemmaModel(
        (embed_tokens): Embedding(256000, 2048, padding_idx=0)
        (layers): ModuleList(
          (0-17): 18 x GemmaDecoderLayer(
            (self_attn): GemmaSdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear4bit(
                (base_la

In [2]:
# Adapter NN for parameter efficient fine-tuning
# Adapters (bottleneck feed-forward networks) are added as modules to the layers of the model
# adapting attention projections and MLP projections while freezing original model parameters

class Adapter(nn.Module):
    def __init__(self, size = 6, model_dim = 2048):
        super().__init__()
        self.adapter_block = nn.Sequential(
            nn.Linear(model_dim, size),
            nn.ReLU(),
            nn.Linear(size, model_dim)
        )

    def forward(self, x):

        output = self.adapter_block(x)
        adapter_out = output + x

        return adapter_out


class Adaptered(nn.Module):
    def __init__(self, orig_layer):
        super().__init__()
        self.orig_layer = orig_layer
        self.adapter = Adapter()

    def forward(self, *x):
        orig_out = self.orig_layer(*x)
        output = (self.adapter.forward(orig_out[0].unsqueeze(0))[0],)

        return output



class model_with_adapter(nn.Module):

    def __init__(self):
        super().__init__()
        self.quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        self.model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto", quantization_config=self.quantization_config,attn_implementation="sdpa")
        # Freeze the original model parameters
        for params in self.model.parameters():
            params.requires_grad = False
        # Embed adapter layers into the transformer blocks 
        for i, gemma_layer in enumerate(self.model.model.layers):
            gemma_layer.self_attn.q_proj = Adaptered(gemma_layer.self_attn.q_proj)
            gemma_layer.self_attn.k_proj = Adaptered(gemma_layer.self_attn.k_proj)
            gemma_layer.self_attn.v_proj = Adaptered(gemma_layer.self_attn.v_proj)
            gemma_layer.self_attn.o_proj = Adaptered(gemma_layer.self_attn.o_proj)
    
            gemma_layer.mlp.gate_proj = Adaptered(gemma_layer.mlp.gate_proj)
            gemma_layer.mlp.up_proj = Adaptered(gemma_layer.mlp.up_proj)
            gemma_layer.mlp.down_proj = Adaptered(gemma_layer.mlp.down_proj)

    def get_model(self):

        return self.model



In [3]:
# Custom get_parameters function
def get_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_percentage = (trainable_params / total_params) * 100

    trainable_params_str = "{:,}".format(trainable_params)
    total_params_str = "{:,}".format(total_params)

    print(f"trainable params: {trainable_params_str} || all params: {total_params_str} || trainable%: {trainable_percentage}")

In [4]:
# Initialization of adapter-model.
# 
gemma = model_with_adapter().to('cuda')

# Model-structure and trainable parameters (this can be tuned by hyperparameters)
get_parameters(gemma)
gemma

# OBS A lot less params, not sure why.. (maybe cause of degeneration? or mistake, need to look into)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 3,355,380 || all params: 1,518,623,476 || trainable%: 0.2209487771674616


model_with_adapter(
  (model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(256000, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x GemmaDecoderLayer(
          (self_attn): GemmaSdpaAttention(
            (q_proj): Adaptered(
              (orig_layer): Linear4bit(in_features=2048, out_features=2048, bias=False)
              (adapter): Adapter(
                (adapter_block): Sequential(
                  (0): Linear(in_features=2048, out_features=6, bias=True)
                  (1): ReLU()
                  (2): Linear(in_features=6, out_features=2048, bias=True)
                )
              )
            )
            (k_proj): Adaptered(
              (orig_layer): Linear4bit(in_features=2048, out_features=256, bias=False)
              (adapter): Adapter(
                (adapter_block): Sequential(
                  (0): Linear(in_features=2048, out_features=6, bias=True)
                  (1): ReLU()
                  (2):

Projection module

In [3]:
embedding_size = 1024
projection_size = 6

class ProjectionNN(nn.Module):
    def __init__(self):
        super(ProjectionNN, self).__init__()

        # Architecture
        self.fc1 = nn.Linear(embedding_size, 128).cuda()
        self.relu = nn.ReLU().cuda()
        self.fc2 = nn.Linear(128, 2048 * projection_size).cuda()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.view(-1,6,2048)
        return x

Fetching and preprocessing of data

In [4]:
df = pd.read_csv(fname)
df_death_small48 = df[((df['img_length_of_stay'] < 48) & (df['death_status'] == 1))]
df_alive_big48 = df[((df['img_length_of_stay'] >= 48) & (df['death_status'] == 0))]
df_death_big48 = df[((df['img_length_of_stay'] >= 48) & (df['death_status'] == 1))]

df_death_small48['y'] = 1
df_alive_big48['y'] = 0
df_death_big48['y'] = 0
df = pd.concat([df_death_small48, df_alive_big48, df_death_big48], axis = 0)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_death_small48['y'] = 1
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_alive_big48['y'] = 0
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_death_big48['y'] = 0


In [5]:
vd_cols = df.filter(regex='^vd_')
y_col = df[['y']]
haim_col = df[['haim_id']]
df = pd.concat([haim_col, vd_cols, y_col], axis=1)

pkl_list = df['haim_id'].unique().tolist()

print(df.head())

     haim_id      vd_0      vd_1      vd_2      vd_3      vd_4      vd_5  \
256     6557  0.005299  0.082119  0.274407  0.017487  0.255308  0.003707   
259     6557  0.000000  0.079306  0.381579  0.015250  0.402685  0.011122   
267     6558  0.005299  0.082119  0.274407  0.017487  0.255308  0.003707   
270     6558  0.000000  0.079306  0.381579  0.015250  0.402685  0.011122   
319     6581  0.002288  0.078941  0.088397  0.017775  0.071482  0.006970   

         vd_6      vd_7      vd_8  ...   vd_1015   vd_1016   vd_1017  \
256  0.137267  0.024046  0.145395  ...  0.008003  0.013876  0.005360   
259  0.125938  0.033254  0.227433  ...  0.042140  0.036560  0.006585   
267  0.137267  0.024046  0.145395  ...  0.008003  0.013876  0.005360   
270  0.125938  0.033254  0.227433  ...  0.042140  0.036560  0.006585   
319  0.223354  0.045017  0.056177  ...  0.004973  0.000343  0.000000   

      vd_1018   vd_1019   vd_1020   vd_1021   vd_1022   vd_1023  y  
256  0.039292  0.029467  0.003972  0.0002

In [6]:

def data_split(df, pkl_list, test_size=0.3, validation_size=0.1, random_state=None):
    # Split into training and test sets
    train_set, test_set = train_test_split(pkl_list, test_size=test_size, random_state=random_state)

    # Further split the training set into training and validation sets
    train_set, validation_set = train_test_split(train_set, test_size=validation_size, random_state=random_state)

    train_idx = df[df['haim_id'].isin(train_set)]['haim_id'].tolist()
    validation_idx = df[df['haim_id'].isin(validation_set)]['haim_id'].tolist()
    test_idx = df[df['haim_id'].isin(test_set)]['haim_id'].tolist()

    x_train = df[df['haim_id'].isin(train_idx)].drop(['haim_id', 'y'], axis=1).values
    x_validation = df[df['haim_id'].isin(validation_idx)].drop(['haim_id', 'y'], axis=1).values
    x_test = df[df['haim_id'].isin(test_idx)].drop(['haim_id', 'y'], axis=1).values

    y_train = df[df['haim_id'].isin(train_idx)]['y'].values
    y_validation = df[df['haim_id'].isin(validation_idx)]['y'].values
    y_test = df[df['haim_id'].isin(test_idx)]['y'].values

    return x_train, x_validation, x_test, y_train, y_validation, y_test

In [None]:
def data_split(df, pkl_list, test_size=0.3, validation_size=0.1, random_state=None):

    types = ['vd_', 'vp_', 'vmd_', 'ts_ce_', 'ts_le_', 'ts_pe_', 'n_rad_']

    x_train = {}
    x_validation = {}
    x_test = {}

    # Split into training and test sets
    train_set, test_set = train_test_split(pkl_list, test_size=test_size, random_state=random_state)

    remaining_data_size = 1.0 - test_size

    validation_size = validation_size*remaining_data_size

    # Further split the training set into training and validation sets
    train_set, validation_set = train_test_split(train_set, test_size=validation_size, random_state=random_state)

    train_idx = df[df['haim_id'].isin(train_set)]['haim_id'].tolist()
    validation_idx = df[df['haim_id'].isin(validation_set)]['haim_id'].tolist()
    test_idx = df[df['haim_id'].isin(test_set)]['haim_id'].tolist()

    for t in types:
        
        x_train[t] = df[df['haim_id'].isin(train_idx)].filter(regex='^'+t).values
        x_validation[t] = df[df['haim_id'].isin(validation_idx)].filter(regex='^'+t).values
        x_test[t] = df[df['haim_id'].isin(test_idx)].filter(regex='^'+t).values

    y_train = df[df['haim_id'].isin(train_idx)]['y'].values
    y_train = [' ###ANSWER: No' if value == 0 else ' ###ANSWER: Yes' for value in y_train]
    y_validation = df[df['haim_id'].isin(validation_idx)]['y'].values
    y_validation = [' ###ANSWER: No' if value == 0 else ' ###ANSWER: Yes' for value in y_validation]
    y_test = df[df['haim_id'].isin(test_idx)]['y'].values
    y_test = [' ###ANSWER: No' if value == 0 else ' ###ANSWER: Yes' for value in y_test]

    return x_train, x_validation, x_test, y_train, y_validation, y_test

In [7]:
x_train, x_val, x_test, y_train, y_val, y_test, = data_split(df, pkl_list)

In [None]:
class DataSplit():

    def __init__(self, df):
        self.df = df
        self.types = ['vd_', 'vp_', 'vmd_', 'ts_ce_', 'ts_le_', 'ts_pe_', 'n_rad_']
        self.partition = None

    def partitiondata(self, partition):
        self.pkl_list = []
        if partition == 'mortality':
            condition_death_small48 = (self.df['img_length_of_stay'] < 48) & (self.df['death_status'] == 1)

            y = [0]*len(self.df)
            for i, condition in enumerate(condition_death_small48):
                if condition:
                    y[i] = 1

        if partition == 'los':
            condition_alive_small48 = self.df[((self.df['img_length_of_stay'] < 48) & (self.df['death_status'] == 0))]

            y = [0]*len(self.df)
            for i, condition in enumerate(condition_alive_small48):
                if condition:
                    y[i] = 1

        self.df['y'] = y

    def get_data(self, partition, test_size=0.3, validation_size=0.1, random_state=None):

        self.partition = partition

        self.partitiondata(partition)
        pkl_list = self.df['haim_id'].unique().tolist()

        # Split into training and test sets
        train_set, test_set = train_test_split(pkl_list, test_size=test_size, random_state=random_state)

        remaining_data_size = 1.0 - test_size
        validation_size = validation_size*remaining_data_size

        # Further split the training set into training and validation sets
        train_set, validation_set = train_test_split(train_set, test_size=validation_size, random_state=random_state)

        train_idx = self.df[self.df['haim_id'].isin(train_set)]['haim_id'].tolist()
        validation_idx = self.df[self.df['haim_id'].isin(validation_set)]['haim_id'].tolist()
        test_idx = self.df[self.df['haim_id'].isin(test_set)]['haim_id'].tolist()

        self.x_train = {t: self.df[self.df['haim_id'].isin(train_idx)].filter(regex='^'+t).values for t in self.types}
        self.x_validation = {t: self.df[self.df['haim_id'].isin(validation_idx)].filter(regex='^'+t).values for t in self.types}
        self.x_test = {t: self.df[self.df['haim_id'].isin(test_idx)].filter(regex='^'+t).values for t in self.types}

        self.y_train = df[df['haim_id'].isin(train_idx)]['y'].values
        self.y_validation = df[df['haim_id'].isin(validation_idx)]['y'].values
        self.y_test = df[df['haim_id'].isin(test_idx)]['y'].values

Creating Dataset and DataLoader

In [9]:
class CustomDataset(Dataset):
    def __init__(self, embedding, labels):
        self.labels = labels
        self.embedding = embedding

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        emb = self.embedding[idx]
        sample = {"Emb": emb, "Class": label}
        return sample

In [12]:
# collate_batch for 0/1 labels

def collate_batch(batch):
     
    emb_list, classes = [], []
    for thing in batch:
        #print(batch)
        emb_list.append(thing['Emb'])
        classes.append(thing['Class'])
    text = torch.tensor(emb_list)
    classes = torch.tensor(classes, dtype=torch.float16)
    return text, classes

In [8]:
model = ProjectionNN()

Trainset = CustomDataset(x_train.tolist(), y_train)
Valset = CustomDataset(x_val, y_gen_val)
Testset = CustomDataset(x_test, y_gen_test)

Train_loader = DataLoader(Trainset, batch_size=1, collate_fn=collate_batch)
Val_loader = DataLoader(Valset, batch_size=1, collate_fn=collate_batch)
Test_loader = DataLoader(Testset, batch_size=1, collate_fn=collate_batch)

NameError: name 'CustomDataset' is not defined

In [11]:
model = ProjectionNN()

Trainset = CustomDataset(x_train.tolist(), y_gen_train)
Valset = CustomDataset(x_val, y_val)
Testset = CustomDataset(x_test, y_test)

Train_loader = DataLoader(Trainset, batch_size=1, collate_fn=collate_batch)
Val_loader = DataLoader(Valset, batch_size=1, collate_fn=collate_batch)
Test_loader = DataLoader(Testset, batch_size=1, collate_fn=collate_batch)

Training functions

In [30]:
def custom_output(emb, gemma):
    outputs = gemma(inputs_embeds=emb)
    noyes = [1294, 3276]
    logits = outputs['logits']
    logits = logits[:,-1,noyes]
    return logits

def output_to_label(logits):
    probs = torch.softmax(logits, dim=-1)
    predicted_token_id = torch.argmax(probs, dim=-1)
    return predicted_token_id

    
def train_epoch(model, gemma, optimizer, loss_fn, train_loader, device):
    # Train:
    model.train()
    train_loss_batches, train_acc_batches = [], []
    num_batches = len(train_loader)
    for batch_index, (x, y) in enumerate(train_loader, 1):
        inputs, labels = x.to(device), y.to(device)
        optimizer.zero_grad()

        emb = model(inputs)
        logits = custom_output(emb.to(torch.float16), gemma)
        
        loss = loss_fn(logits, labels.float())
        loss.backward()
        optimizer.step()
        train_loss_batches.append(loss.item())

        hard_preds = output_to_label(logits)
        acc_batch_avg = (hard_preds == labels).float().mean().item()
        train_acc_batches.append(acc_batch_avg)

    return model, train_loss_batches, train_acc_batches

def validate(model, gemma, loss_fn, val_loader, device, word_embs):
    val_loss_cum = 0
    val_acc_cum = 0
    model.eval()
    with torch.no_grad():
        for batch_index, (x, y) in enumerate(val_loader, 1):
            inputs, labels = x.to(device), y.to(device)
            emb = model.forward(inputs)
            concatted = torch.cat((word_embs,emb), dim=1).to(torch.float16)
            logits = custom_output(concatted, gemma)

            batch_loss = loss_fn(logits, labels.float())
            val_loss_cum += batch_loss.item()
            hard_preds = output_to_label(logits)
            acc_batch_avg = (hard_preds == labels).float().mean().item()
            val_acc_cum += acc_batch_avg
    return val_loss_cum/len(val_loader), val_acc_cum/len(val_loader)

In [31]:
def training_loop(model, gemma, optimizer, loss_fn, train_loader, val_loader, num_epochs):
    print("Starting training")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    for epoch in range(1, num_epochs+1):
        model, train_loss, train_acc = train_epoch(model, gemma,
                                                   optimizer,
                                                   loss_fn,
                                                   train_loader,
                                                   device)
        val_loss, val_acc = validate(model, loss_fn, val_loader, device)
        print(f"Epoch {epoch}/{num_epochs}: "
              f"Train loss: {sum(train_loss)/len(train_loss):.3f}, "
              f"Train acc.: {sum(train_acc)/len(train_acc):.3f}, "
              f"Val. loss: {val_loss:.3f}, "
              f"Val. acc.: {val_acc:.3f}")
        train_losses.extend(train_loss)
        train_accs.extend(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
    return model, train_losses, train_accs, val_losses, val_accs

In [41]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
training_loop(model, gemma, optimizer, loss_fn, Train_loader, Val_loader, 2)

Starting training
cuda


KeyboardInterrupt: 

Experiments

In [67]:
for batch_index, (x, y, z) in enumerate(Train_loader, 1):
    print(f"Batch {batch_index} - Inputs: {x}, Labels: {y}, Instructions {z}")

    if batch_index > 3:
        break  # Break the loop after printing the first three batches

Batch 1 - Inputs: tensor([[0.0023, 0.0789, 0.0884,  ..., 0.0000, 0.0134, 0.1688]]), Labels: tensor([[     2,   6176, 221049, 235292,  11396,    736,   7679,   4692,   5543,
           1178, 235248, 235310, 235321,    531, 235336,  43774,  11189,  44702,
         235292, 235248]]), Instructions tensor([[     2,  43774,  72601, 235292,   6287]])
Batch 2 - Inputs: tensor([[0.0104, 0.0867, 0.1055,  ..., 0.0014, 0.0057, 0.0628]]), Labels: tensor([[     2,   6176, 221049, 235292,  11396,    736,   7679,   4692,   5543,
           1178, 235248, 235310, 235321,    531, 235336,  43774,  11189,  44702,
         235292, 235248]]), Instructions tensor([[     2,  43774,  72601, 235292,   6287]])
Batch 3 - Inputs: tensor([[0.0067, 0.0579, 0.1242,  ..., 0.0031, 0.0000, 0.0049]]), Labels: tensor([[     2,   6176, 221049, 235292,  11396,    736,   7679,   4692,   5543,
           1178, 235248, 235310, 235321,    531, 235336,  43774,  11189,  44702,
         235292, 235248]]), Instructions tensor([[    

In [9]:
# dataset for generative

class CustomDataset(Dataset):
    def __init__(self, embedding, labels):
        self.labels = labels
        self.embedding = embedding
        self.instruction = "###INSTRUCTION: Did this patient stay longer than 48 h? ###MODALITY: "

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        emb = self.embedding[idx]
        inst = self.instruction
        sample = {"Emb": emb, "Class": label, 'Inst': inst}
        #print(sample)
        return sample

In [10]:
# collate_batch for generative

def collate_batch(batch):
     
    emb_list, classes, instructions = [], [], []
    for thing in batch:
        emb_list.append(thing['Emb'])
        classes.append(tokenizer(thing['Class'], return_tensors="pt"))
        instructions.append(tokenizer(thing['Inst'], return_tensors="pt"))
    text = torch.tensor(emb_list)
    classes = torch.cat([item['input_ids'] for item in classes], dim=0)
    instructions = torch.cat([item['input_ids'] for item in instructions], dim=0)
    return text, instructions, classes

In [12]:
# FIRST APPROACH GEN TRYOUT

def train_epoch(model, gemma, optimizer, loss_fn, train_loader, device):
    # Train:
    model.train()
    train_loss_batches, train_acc_batches = [], []
    num_batches = len(train_loader)
    embedding_matrix = gemma.get_input_embeddings().weight
    for batch_index, (mod, inst, label) in enumerate(train_loader, 1):
        mod_embeddings = model(mod.to(device))
        print(mod_embeddings.device)
        inst_list = [embedding_matrix[token_id].to(device) for token_id in inst.to(dtype=torch.long)]
        label_list = [embedding_matrix[token_id].to(device) for token_id in label.to(dtype=torch.long)]
        inst_embeddings = torch.stack(inst_list)
        label_embeddings = torch.stack(label_list)
        print(label)
        
        mod_embeddings = mod_embeddings.to(device)
        #print(mod_embeddings.shape)
        #print(inst_embeddings.shape)
        #print(label_embeddings.shape)

        optimizer.zero_grad()

        conc_emb = torch.cat([inst_embeddings.to(dtype=torch.float16).to(device), mod_embeddings.to(device), label_embeddings.to(dtype=torch.float16).to(device)], dim=1).to(device)
        target = label.to(dtype=torch.float16)
        print(target.shape)
        print(conc_emb.shape)
        output = gemma(inputs_embeds=conc_emb.to(dtype=torch.float16))
        print('forward: ', output['logits'].shape)
        probabilities = F.softmax(output['logits'], dim=-1)
        print('probabilities: ', probabilities.shape)
        predicted_ids = torch.argmax(probabilities, dim=-1)
        print('predicted ids: ', predicted_ids.shape)

        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        train_loss_batches.append(loss.item())

        hard_preds = output_to_label(logits)
        acc_batch_avg = (hard_preds == labels).float().mean().item()
        train_acc_batches.append(acc_batch_avg)

    return model, train_loss_batches, train_acc_batches

def validate(model, gemma, loss_fn, val_loader, device, word_embs):
    val_loss_cum = 0
    val_acc_cum = 0
    model.eval()
    with torch.no_grad():
        for batch_index, (x, y) in enumerate(val_loader, 1):
            inputs, labels = x.to(device), y.to(device)
            emb = model.forward(inputs)
            concatted = torch.cat((word_embs,emb), dim=1).to(torch.float16)
            logits = custom_output(concatted, gemma)

            batch_loss = loss_fn(logits, labels.float())
            val_loss_cum += batch_loss.item()
            hard_preds = output_to_label(logits)
            acc_batch_avg = (hard_preds == labels).float().mean().item()
            val_acc_cum += acc_batch_avg
    return val_loss_cum/len(val_loader), val_acc_cum/len(val_loader)

In [None]:
def left_padding(concatenated_emb, label, device):
    prompt_length = concatenated_emb.size(1) - label[0].size(0)
    padded_labels= torch.cat([torch.full((prompt_length,), -100).to(device), label[0].to(device)]).to(device)
    return padded_labels

In [None]:
def masking(logits, padded_target):
    print(padded_target)
    mask = (padded_target != -100)
    logits_masked = logits[:, mask, :]
    targets_masked = padded_target[mask].view(-1)
    return logits_masked, targets_masked

In [40]:
# SECOND APPROACH GEN TRYOUT

def output_to_label(logits):
    probs = torch.softmax(logits, dim=-1)
    predicted_token_id = torch.argmax(probs, dim=-1)
    return predicted_token_id

def train_epoch(model, gemma, optimizer, loss_fn, train_loader, device):
    # Train:
    gemma.train()
    train_loss_batches, train_acc_batches = [], []
    num_batches = len(train_loader)
    embedding_matrix = gemma.get_input_embeddings().weight
    for batch_index, (mod, inst, label) in enumerate(train_loader, 1):
        mod_embeddings = model(mod.to(device))
        inst_list = [embedding_matrix[token_id].to(device) for token_id in inst.to(dtype=torch.long)]
        label_list = [embedding_matrix[token_id].to(device) for token_id in label.to(dtype=torch.long)]
        inst_embeddings = torch.stack(inst_list)
        label_embeddings = torch.stack(label_list)
        
        mod_embeddings = mod_embeddings.to(device)

        optimizer.zero_grad()

        conc_emb = torch.cat([inst_embeddings.to(dtype=torch.float16).to(device), mod_embeddings.to(device), label_embeddings.to(dtype=torch.float16).to(device)], dim=1).to(device)
        prompt_length = conc_emb.size(1) - label[0].size(0)

        padding_tensor = torch.cat([torch.full((prompt_length,), -100).to('cuda'), label[0].to(device)]).to(device)


        output = gemma(inputs_embeds=conc_emb.to(dtype=torch.float16), labels=padding_tensor)

        logits = output['logits'] #torch.Size([1, 31, 256000])
        mask = (padding_tensor != -100) #torch.Size([31])
        logits_masked = logits[:, mask, :]

        targets_masked = padding_tensor[mask].view(-1)

        loss = loss_fn(logits_masked.squeeze(0), targets_masked)
        loss.backward()
        optimizer.step()
        train_loss_batches.append(loss.item())

        hard_preds = output_to_label(logits_masked)
        acc_batch_avg = (hard_preds == targets_masked).float().mean().item()
        train_acc_batches.append(acc_batch_avg)

    return model, train_loss_batches, train_acc_batches

def validate(model, gemma, loss_fn, val_loader, device, word_embs):
    val_loss_cum = 0
    val_acc_cum = 0
    gemma.eval()
    embedding_matrix = gemma.get_input_embeddings().weight
    with torch.no_grad():
        for batch_index, (mod, inst, label) in enumerate(train_loader, 1):
            mod_embeddings = model(mod.to(device))
            inst_list = [embedding_matrix[token_id].to(device) for token_id in inst.to(dtype=torch.long)]
            label_list = [embedding_matrix[token_id].to(device) for token_id in label.to(dtype=torch.long)]
            inst_embeddings = torch.stack(inst_list)
            label_embeddings = torch.stack(label_list)
        
            mod_embeddings = mod_embeddings.to(device)

            optimizer.zero_grad()

            conc_emb = torch.cat([inst_embeddings.to(dtype=torch.float16).to(device), mod_embeddings.to(device), label_embeddings.to(dtype=torch.float16).to(device)], dim=1).to(device)
            prompt_length = conc_emb.size(1) - label[0].size(0)

            padding_tensor = torch.cat([torch.full((prompt_length,), -100).to('cuda'), label[0].to(device)]).to(device)


            output = gemma(inputs_embeds=conc_emb.to(dtype=torch.float16), labels=padding_tensor)

            logits = output['logits'] #torch.Size([1, 31, 256000])
            mask = (padding_tensor != -100) #torch.Size([31])
            logits_masked = logits[:, mask, :]

            targets_masked = padding_tensor[mask].view(-1)


            loss = loss_fn(logits_masked.squeeze(0), targets_masked)
            val_loss_cum += loss.item()
            hard_preds = output_to_label(logits_masked)
            acc_batch_avg = (hard_preds == targets_masked).float().mean().item()
            val_acc_cum += acc_batch_avg
    return val_loss_cum/len(val_loader), val_acc_cum/len(val_loader)

In [14]:
def training_loop(model, gemma, optimizer, loss_fn, train_loader, val_loader, num_epochs):
    print("Starting training")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    model.to(device)
    gemma.to(device)
    #for name, param in gemma.named_parameters():
    #    print(f"{name}: {param.device}")
    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    for epoch in range(1, num_epochs+1):
        model, train_loss, train_acc = train_epoch(model, gemma,
                                                   optimizer,
                                                   loss_fn,
                                                   train_loader,
                                                   device)
        val_loss, val_acc = validate(model, loss_fn, val_loader, device)
        print(f"Epoch {epoch}/{num_epochs}: "
              f"Train loss: {sum(train_loss)/len(train_loss):.3f}, "
              f"Train acc.: {sum(train_acc)/len(train_acc):.3f}, "
              f"Val. loss: {val_loss:.3f}, "
              f"Val. acc.: {val_acc:.3f}")
        train_losses.extend(train_loss)
        train_accs.extend(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
    return model, train_losses, train_accs, val_losses, val_accs

In [19]:
nos = ['No', 'no', ' No', ' no', 'no ', 'No ']
yes = ['Yes', 'yes', ' Yes', ' yes', 'yes ', 'Yes ']
embedding_matrix = gemma.get_input_embeddings().weight
print(len(embedding_matrix))


no_labels = []
yes_labels = []
for n in tqdm(range(256000)):
    t = tokenizer.decode(n)
    if t in yes:
        yes_labels.append(n)
    if t in nos:
        no_labels.append(n)

print('No: ', no_labels)
print('Yes: ', yes_labels)

256000


100%|██████████| 256000/256000 [00:01<00:00, 178866.40it/s]

No:  [793, 956, 1294, 1307]
Yes:  [3276, 3553, 6287, 7778]





In [116]:
# Trying out finetune of generative aspect of gemma

embedding_matrix = gemma.get_input_embeddings().weight
prompt = "### INSTRUCTION: Did this patient stay longer than 48 h?"
lab = '### ANSWER: Yes'
mod = torch.randn(1, 1024).to('cuda')
prompt_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
lab_ids = tokenizer(lab, return_tensors="pt").to("cuda")
print(len(lab_ids))
#print(prompt_ids["input_ids"].shape)
prompt_embeddings = [embedding_matrix[token_id].to('cuda') for token_id in prompt_ids["input_ids"][0]]
#print(lab.shape)
prompt_embeddings = torch.stack(prompt_embeddings)
label_embeddings = [embedding_matrix[token_id].to('cuda') for token_id in lab_ids["input_ids"][0]]
label_embeddings = torch.stack(label_embeddings)

p_mod = model(mod)

conc_emb = torch.cat([prompt_embeddings.unsqueeze(0).to(dtype=torch.float16), p_mod.to('cuda').to(dtype=torch.float16), label_embeddings.unsqueeze(0).to(dtype=torch.float16)], dim=1)

prompt_length = conc_emb.size(1) - lab_ids['input_ids'][0].size(0)
print(prompt_length)

padding_tensor = torch.cat([torch.full((prompt_length,), -100).to('cuda'), lab_ids['input_ids'][0]])

#print(padding_tensor)
#print(conc_emb.shape)
conc_emb.to(dtype=torch.float16).to('cuda')
output = gemma(inputs_embeds=conc_emb, labels=padding_tensor)
#print(output)

probabilities = F.softmax(output['logits'], dim=-1)
predicted_ids = torch.argmax(probabilities, dim=-1)
print(predicted_ids)

tokenizer.decode(predicted_ids[0])
#tokenizer.decode(0)

2
21
tensor([[229711,    714, 235292,    109,    692,   1707,    791,    575,    575,
           5043, 235304,   3763,   3763, 235336,    109,  61278, 229711, 222975,
           1692, 130983,  61278,    109,    109, 235292,    590, 235269]],
       device='cuda:0')


' increa The:\n\n you help have in in expected3 hours hours?\n\nMathML increa WhencetabularBASELINEMathML\n\n\n\n: I,'