In [1]:
from triplet_dataset import get_dataset, get_sentence_id_label_df, TripletDataset
from network import get_sts_model
from tqdm.auto import tqdm
import numpy as np
from model_evaluation import  compare_sactter_plots
from sklearn.manifold import TSNE
import torch
import os
import torch.nn as nn
from peft import LoraConfig

RANK = 16
PEFT_CONFIG = LoraConfig(inference_mode=False, 
              r=RANK, 
              lora_alpha=RANK*2, 
              lora_dropout=0.05,
              target_modules=['value','query','key']
              )
global DEVICE
global MODEL_PATH
global MODEL
global TOKENIZER
global DATA
global VIS_DATA

DEVICE = 'cuda'
MODEL_PATH = 'sentence-transformers/all-MiniLM-L12-v1'
MODEL = get_sts_model(model_path=MODEL_PATH, device=DEVICE, pef_config=PEFT_CONFIG)
TOKENIZER = MODEL.tokenizer
DATA = get_dataset()
VIS_DATA = get_sentence_id_label_df()


def gen_vis_embeddings(no_peft=False):
    embeddings = []
    sentences = VIS_DATA['sentence'].tolist()
    for sentence in tqdm(sentences, unit='sentence', desc='Generating embeddings'):
        if no_peft:
            print('Do Something')
        else:
            embedding = MODEL(sentence).detach().cpu().numpy()
        embeddings.append(embedding)
    embeddings = np.array(embeddings).squeeze()
    embeddings = TSNE(n_components=2).fit_transform(embeddings)
    return embeddings


trainable params: 442368 || all params: 33802368 || trainable%: 1.3086893793949583


Creating triplets:   0%|          | 0/14 [00:00<?, ?group/s]

In [4]:
def main():
    # Model without
    # bare_embeddings = gen_vis_embeddings(no_peft=True)
    # ids = VIS_DATA['id'].tolist()
    # print('Bare model Clustering: ')
    # compare_sactter_plots(bare_embeddings, None, ids)
    # print('='*100)
    batch_size = 8
    optimizer = torch.optim.AdamW(params=MODEL.parameters(), lr=1e-5)
    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
    num_epochs = 5
    eval_every = 100
    save_model_every = 1000
    print('Training model...')
    train(batch_size, optimizer, triplet_loss, num_epochs, eval_every, save_model_every)
    

def train(batch_size,optimizer, triplet_loss, num_epochs, eval_every, save_model_every):
    for epoch in range(num_epochs):
        train_dataset = TripletDataset(DATA, tokenizer=TOKENIZER, device=DEVICE, batch_size=batch_size, shuffle=True, max_len=200)
        loss = 0
        steps = 0
        accumelated_loss = 0
        tbar = tqdm(train_dataset, unit='batch')
        for input in tbar:
                steps += 1
                anchor = MODEL(input[0])
                positive = MODEL(input[1])
                negative = MODEL(input[2])
                loss = triplet_loss(anchor, positive, negative)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                accumelated_loss += loss.item()
                tbar.set_description(f'Epoch {epoch} current loss: {loss:.2f} avaerage loss: {accumelated_loss/steps:.2f}')



                if steps % eval_every == 0:
                    print('Evaluating model')
                    print(f'Avaerage loss: {accumelated_loss/steps:.2f} ')
                    # embeddings = gen_vis_embeddings()
                    # ids = VIS_DATA['id'].tolist()
                    # print('Model Clustering: ')
                    # compare_sactter_plots(embeddings,None, ids)
                    # print('='*100)
                if steps % 200 == 0:
                    accumelated_loss = 0

                if steps % save_model_every == 0:
                    print('Saving model')
                    if not os.path.exists('./models'):
                        os.makedirs('./models')
                    
                    lora_model = MODEL.Bert_representations

                    if not os.path.exists('./models/LoRa'):
                        os.makedirs('./models/LoRa')

                    #model name without / character
                    short_model_name = MODEL_PATH.split('/')[-1]
                    lora_save_path = f'./models/LoRa/lora_model_{short_model_name}_{steps}'
                    lora_model.save_pretrained(lora_save_path)

In [5]:
main()


Training model...


Batching data:   0%|          | 0/7112 [00:00<?, ?row/s]

Tokenizing data:   0%|          | 0/890 [00:00<?, ?batch/s]

  0%|          | 0/890 [00:00<?, ?batch/s]

Evaluating model
Avaerage loss: 0.64 
Evaluating model
Avaerage loss: 0.00 
Evaluating model
Avaerage loss: 0.18 
Evaluating model
Avaerage loss: 0.00 
