In [2]:
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from train_lstm import instantiate_model, LSTM_AE, QNA_DATA_DIR, BASE_DATA_DIR

# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# LSTM Embeddings

In [3]:
class MCQDataset(torch.utils.data.Dataset):
    """Sequences of sentence embeddings of a user's answered questions.

    Each row of the pickled DataFrame is one answered question. Item ``idx`` is
    that row plus up to ``seq_len - 1`` earlier rows of the same user (rows are
    ordered by ``start_time``). Every selected row is encoded as its question
    embedding concatenated with its answer embedding.

    Columns read from the DataFrame: ``user_id``, ``start_time``, ``question``
    and ``choice``. ``__init__`` adds ``question_embedding`` and
    ``answer_embedding``.
    """

    # SentenceTransformer shared by all instances; created on first use.
    _nlp_model = None

    @property
    def nlp_model(self):
        """The shared SentenceTransformer, loaded on first access."""
        # `is None` instead of truthiness: a model object may define __len__
        # (torch's nn.Sequential does). Its truth value would then depend on its
        # length, not on whether it has been loaded.
        if MCQDataset._nlp_model is None:
            from sentence_transformers import SentenceTransformer
            MCQDataset._nlp_model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
        return MCQDataset._nlp_model

    def __init__(self, datapath, seq_len=5, progress_bar=True):
        """Load the DataFrame and embed every question and answer up front.

        Args:
            datapath: path of the pickled DataFrame, read with ``pd.read_pickle``.
                Only load trusted files: unpickling can execute code.
            seq_len: maximum number of rows per item, the current row included.
            progress_bar: show the SentenceTransformer encoding progress bar.
        """
        self.datapath = datapath
        self.seq_len = seq_len
        self.progress_bar = progress_bar
        self.df = pd.read_pickle(self.datapath)
        # Maps user_id -> positional row indices; built on the first __getitem__.
        self._positions_by_user = None

        self.df['question_embedding'] = self._create_q_embeddings()
        self.df['answer_embedding'] = self._create_a_embeddings()

    def _encode_column(self, column):
        """Embed every string of ``self.df[column]``; return one 1-D array per row."""
        # encode() is documented to take a list of strings, so pass a plain list.
        # A Series would carry index labels that need not be 0..n-1.
        embeddings = self.nlp_model.to(device).encode(
            self.df[column].tolist(), show_progress_bar=self.progress_bar, batch_size=2048
        )
        # Iterating a 2-D array yields its rows. Unlike np.split, this also works for 0 rows.
        return list(embeddings)

    def _create_q_embeddings(self):
        """One embedding per row of the ``question`` column."""
        return self._encode_column("question")

    def _create_a_embeddings(self):
        """One embedding per row of the ``choice`` (answer) column."""
        return self._encode_column("choice")

    def __len__(self):
        return len(self.df)

    def _user_positions(self):
        """Map each ``user_id`` to the positional indices of its rows (cached)."""
        # getattr also tolerates instances created without running __init__.
        if getattr(self, "_positions_by_user", None) is None:
            self._positions_by_user = self.df.groupby("user_id").indices
        return self._positions_by_user

    def __getitem__(self, idx):
        """Return the embedding sequence that ends at row position ``idx``.

        Negative ``idx`` counts from the end, as for a list.

        Returns:
            A float32 tensor with one row per selected answer (at most
            ``seq_len`` rows), in ascending ``start_time`` order. Each row is the
            question embedding followed by the answer embedding.
        """
        if idx < 0:
            idx = len(self.df) + idx

        # Only this user's rows are touched, instead of scanning the whole frame per item.
        positions = self._user_positions()[self.df["user_id"].iloc[idx]]

        # Sort the user's rows by time. After reset_index, each row's index is its
        # offset into `positions`, which maps it back to its global position.
        user_rows = self.df.iloc[positions].reset_index(drop=True).sort_values(by="start_time")
        sorted_positions = positions[user_rows.index.to_numpy()]
        # `idx` is a position, so compare it with positions, never with index labels.
        rank = int(np.flatnonzero(sorted_positions == idx)[0])

        window = user_rows.iloc[max(0, rank - self.seq_len + 1): rank + 1]
        return torch.stack([
            torch.tensor(np.concatenate((q_emb, a_emb)), dtype=torch.float32)
            for q_emb, a_emb in zip(window["question_embedding"], window["answer_embedding"])
        ])

In [4]:
# Full expanded Q&A dump. For a quicker validation-only run, point datapath at
# f"{QNA_DATA_DIR}/validation/qna_expanded.pkl" instead.
# The constructor embeds every question and answer, so this cell is slow.
dataset = MCQDataset(datapath=f"{QNA_DATA_DIR}/all_data_qna_expanded.pkl")

Batches:   0%|          | 0/108 [00:00<?, ?it/s]

Batches:   0%|          | 0/108 [00:00<?, ?it/s]

In [19]:
# Selects which trained LSTM autoencoder to load. The checkpoint directory name
# encodes seq_len and len(h_dims) only, not the individual hidden sizes.
seq_len = 20
h_dims = [384]
lstm_checkpoint_path = f"{BASE_DATA_DIR}/../checkpoints/seq_len_{seq_len}_h_dims_{len(h_dims)}/model_100.pt"

# NOTE(review): the meaning of the positional 384 is defined by
# train_lstm.instantiate_model — confirm it matches the checkpoint.
model = instantiate_model(LSTM_AE, dataset, 384, h_dims=h_dims)
# map_location lets tensors saved from a GPU run load when only the CPU is available.
model.load_state_dict(torch.load(lstm_checkpoint_path, map_location=device))
model = model.to(device)

In [20]:
from train_lstm import get_encodings
from pathlib import Path

# The checkpoint loaded above is keyed by `seq_len`; encode with the same window length.
dataset.seq_len = seq_len
embeddings = get_encodings(model, dataset)

# Presumably one tensor per sample (TODO confirm against train_lstm.get_encodings).
# Stack them and move the result to CPU memory for saving.
embeddings = torch.stack(embeddings).detach().cpu()

# Build the output directory once, so the mkdir and the save cannot drift apart.
embeddings_dir = Path(f"{BASE_DATA_DIR}/lernnavi/embeddings")
embeddings_dir.mkdir(parents=True, exist_ok=True)
torch.save(embeddings, embeddings_dir / f"lstm_seq_len_{seq_len}_h_dims_{len(h_dims)}.pt")

  0%|          | 0/220977 [00:00<?, ?it/s]

# BERT Embeddings

In [3]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

# One constant for the checkpoint, so the tokenizer and the model always match.
BERT_MODEL_ID = "lucazed/LernnaviBERT"
# NOTE: the text MCQDataset below reads this global `tokenizer` (its sep_token) when building prompts.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_ID)
model = AutoModelForMaskedLM.from_pretrained(BERT_MODEL_ID).to(device)

In [3]:
import warnings
from bs4 import BeautifulSoup

def remove_html_tags(input):
    """Return the visible text of an HTML fragment, with surrounding whitespace trimmed.

    The fragment is parsed with BeautifulSoup's ``html.parser`` backend.
    """
    text = BeautifulSoup(input, 'html.parser').get_text()
    return text.strip()

class MCQDataset(torch.utils.data.Dataset):
    """Text prompts built from a user's answered questions, for transformer encoders.

    Item ``idx`` is row ``idx`` plus up to ``seq_len - 1`` earlier rows of the
    same user (rows are ordered by ``start_time``). The rows are rendered as
    ``Q: <question><sep>A: <choice>`` pairs joined by ``<sep>``, and HTML tags
    are then removed with ``remove_html_tags``.

    Columns read from the DataFrame: ``user_id``, ``start_time``, ``question``, ``choice``.
    """

    def __init__(self, datapath, seq_len=5, progress_bar=True, sep_token=None):
        """Load the pickled DataFrame.

        Args:
            datapath: path of the pickled DataFrame, read with ``pd.read_pickle``.
                Only load trusted files: unpickling can execute code.
            seq_len: maximum number of Q/A pairs per item, the current one included.
            progress_bar: stored for interface parity; not read by this class.
            sep_token: separator placed between Q and A and between pairs.
                ``None`` (the default) reads the global ``tokenizer.sep_token`` on
                every access, and falls back to a newline when that is unset.
        """
        self.datapath = datapath
        self.seq_len = seq_len
        self.progress_bar = progress_bar
        self.sep_token = sep_token
        self.df = pd.read_pickle(self.datapath)
        # Maps user_id -> positional row indices; built on the first __getitem__.
        self._positions_by_user = None

    def __len__(self):
        return len(self.df)

    def _user_positions(self):
        """Map each ``user_id`` to the positional indices of its rows (cached)."""
        # getattr also tolerates instances created without running __init__.
        if getattr(self, "_positions_by_user", None) is None:
            self._positions_by_user = self.df.groupby("user_id").indices
        return self._positions_by_user

    def _separator(self):
        """Separator used inside and between Q/A pairs."""
        explicit = getattr(self, "sep_token", None)
        if explicit is not None:
            return explicit
        # Reads the notebook-global `tokenizer`. The Mistral section rebinds it to a
        # causal-LM tokenizer that may have no sep_token, and formatting None would
        # insert the literal text "None" into every prompt.
        global_sep = tokenizer.sep_token
        return "\n" if global_sep is None else global_sep

    def __getitem__(self, idx):
        """Return the HTML-stripped prompt for the window ending at row position ``idx``.

        Negative ``idx`` counts from the end, as for a list.
        """
        if idx < 0:
            idx = len(self.df) + idx

        # Only this user's rows are touched, instead of scanning the whole frame per item.
        positions = self._user_positions()[self.df["user_id"].iloc[idx]]

        # Sort the user's rows by time. After reset_index, each row's index is its
        # offset into `positions`, which maps it back to its global position.
        user_rows = self.df.iloc[positions].reset_index(drop=True).sort_values(by="start_time")
        sorted_positions = positions[user_rows.index.to_numpy()]
        # `idx` is a position, so compare it with positions, never with index labels.
        rank = int(np.flatnonzero(sorted_positions == idx)[0])

        window = user_rows.iloc[max(0, rank - self.seq_len + 1): rank + 1]

        sep = self._separator()
        data = sep.join(
            f"Q: {question}{sep}A: {choice}"
            for question, choice in zip(window["question"], window["choice"])
        )

        # Silence warnings emitted while stripping HTML (BeautifulSoup may warn about input it is given).
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            return remove_html_tags(data)

In [4]:
# Each sample is a prompt with up to this many Q/A pairs of one user, the current pair included.
seq_len = 10
dataset = MCQDataset(datapath=f"{QNA_DATA_DIR}/all_data_qna_expanded.pkl", seq_len=seq_len)

In [12]:
# One vector per sample: the final hidden layer averaged over every token of the
# prompt. truncation=True cuts over-long prompts to the tokenizer's maximum length.
bert_embeddings = []
with torch.no_grad():
    for i in tqdm(range(len(dataset))):
        encoded = tokenizer(dataset[i], return_tensors="pt", truncation=True).to(device)
        last_hidden = model(**encoded, output_hidden_states=True).hidden_states[-1]
        bert_embeddings.append(last_hidden.squeeze(0).mean(0).cpu())

100%|██████████| 220977/220977 [3:11:35<00:00, 19.22it/s]  


In [13]:
import warnings
from pathlib import Path

# One row per sample: (num_samples, hidden_size). list() keeps the cell
# re-runnable. After a first run `bert_embeddings` is already a tensor, and
# torch.stack() only accepts a sequence of tensors.
bert_embeddings = torch.stack(list(bert_embeddings))

# This section may run in a fresh kernel (its execution counts restart), so do
# not rely on an earlier cell having created the output directory.
embeddings_dir = Path(f"{BASE_DATA_DIR}/lernnavi/embeddings")
embeddings_dir.mkdir(parents=True, exist_ok=True)

# An interrupted embedding loop must not be saved under the canonical file name.
suffix = ""
if len(bert_embeddings) != len(dataset):
    suffix = f"_partial_{len(bert_embeddings)}"
    warnings.warn(f"only {len(bert_embeddings)} of {len(dataset)} samples were embedded; saving with suffix {suffix}")
torch.save(bert_embeddings, embeddings_dir / f"bert_seq_len_{seq_len}{suffix}.pt")

# Mistral 7B embeddings

In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"

# Rebinding the global `tokenizer` also changes the separator that the text
# MCQDataset reads from it when building prompts in the loop below.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# float16 weights take half the memory of float32 for this 7B-parameter model.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# One vector per sample: the final hidden layer averaged over every prompt token,
# with prompts truncated to 4096 tokens.
# NOTE(review): dataset[i] builds its prompt using the global `tokenizer`, which is now Mistral's.
mistral_embeddings = []
with torch.no_grad():
    for i in tqdm(range(len(dataset))):
        encoded = tokenizer(dataset[i], return_tensors="pt", truncation=True, max_length=4096).to(device)
        last_hidden = model(**encoded, output_hidden_states=True).hidden_states[-1]
        mistral_embeddings.append(last_hidden.squeeze(0).mean(0).cpu())

  0%|          | 222/220977 [00:32<9:01:37,  6.79it/s] 


KeyboardInterrupt: 

In [9]:
import warnings
from pathlib import Path

# One row per sample: (num_samples, hidden_size). list() keeps the cell
# re-runnable. After a first run `mistral_embeddings` is already a tensor, and
# torch.stack() only accepts a sequence of tensors.
mistral_embeddings = torch.stack(list(mistral_embeddings))

# Do not rely on an earlier cell having created the output directory.
embeddings_dir = Path(f"{BASE_DATA_DIR}/lernnavi/embeddings")
embeddings_dir.mkdir(parents=True, exist_ok=True)

# The embedding loop can be interrupted (it was, after 222 samples). A partial
# result gets a distinct file name instead of masquerading as the full set.
suffix = ""
if len(mistral_embeddings) != len(dataset):
    suffix = f"_partial_{len(mistral_embeddings)}"
    warnings.warn(f"only {len(mistral_embeddings)} of {len(dataset)} samples were embedded; saving with suffix {suffix}")
torch.save(mistral_embeddings, embeddings_dir / f"mistral_seq_len_{seq_len}{suffix}.pt")