In [None]:
import os
from tqdm import tqdm
from functools import partial

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

In [None]:
# Path (relative to the working directory) of the file that
# get_embeddings_labels writes the embedding tensor to via torch.save.
out_pt_file = "scores.pt"

In [None]:
class LM(nn.Module):
    """Wrap a pretrained transformer and return one pooled vector per input row.

    Args:
        decoder: The underlying transformer. For the Mistral identifier it is
            called directly; otherwise it must expose the BERT-style attributes
            ``embeddings``, ``get_extended_attention_mask``, ``encoder`` and
            ``pooler``.
        lm_model: Model identifier string; selects which pooling branch runs.
    """

    def __init__(self, decoder, lm_model):
        super().__init__()
        self.decoder = decoder
        self.lm_model = lm_model

    def forward(self, input_ids):
        """Compute one pooled hidden-state vector per row of the batch.

        Args:
            input_ids: Mapping with ``"input_ids"`` and ``"attention_mask"``
                entries (despite the name, this is the whole tokenizer output,
                not only the token ids).

        Returns:
            For the Mistral identifier, the final-layer hidden state of the
            last non-padding token of each row; otherwise whatever
            ``decoder.pooler`` returns for the encoder output.
        """
        token_ids = input_ids["input_ids"]
        attention_mask = input_ids["attention_mask"]

        if self.lm_model == "mistralai/Mistral-7B-v0.1":
            transformer_outputs = self.decoder(token_ids, attention_mask=attention_mask)
            token_states = transformer_outputs[0]
            # A causal decoder's first position cannot attend to later tokens,
            # so pooling index 0 yields a vector that does not depend on the
            # text. Pool the last real (non-padding) token instead.
            # Multiplying the 0/1 mask by the position index and taking the max
            # gives that token's index for both left- and right-padded rows
            # (a fully masked row falls back to index 0).
            positions = torch.arange(
                attention_mask.size(1), device=attention_mask.device
            )
            last_real = (attention_mask.long() * positions).max(dim=1).values
            rows = torch.arange(token_states.size(0), device=token_states.device)
            hidden_states = token_states[rows, last_real.to(token_states.device)]
        else:
            # Run the encoder stack step by step and use the model's own pooler.
            bert_embeddings = self.decoder.embeddings(input_ids=token_ids)
            extended_attention_mask = self.decoder.get_extended_attention_mask(
                attention_mask, token_ids.size()
            )
            outputs = self.decoder.encoder(
                bert_embeddings, attention_mask=extended_attention_mask
            )
            sequence_output = outputs[0]
            hidden_states = self.decoder.pooler(sequence_output)
        return hidden_states

In [None]:
class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset that hands back the stored items unchanged."""

    def __init__(self, texts):
        # Stored by reference; no copying or preprocessing happens here.
        self.texts = texts

    def __len__(self):
        """Number of stored items."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        return self.texts[idx]


def collate_fn(batch, tokenizer):
    """Tokenize a batch of texts in a single call.

    ``tokenizer`` is invoked with ``padding="longest"`` and
    ``return_tensors="pt"``; its result is returned unchanged.
    """
    return tokenizer(batch, return_tensors="pt", padding="longest")

In [None]:
def get_embeddings_labels(lm_model, csv_path, device=None, batch_size=40, out_path=None):
    """Embed every distinct description in a CSV and save the stacked tensor.

    Args:
        lm_model: Model identifier passed to ``AutoTokenizer.from_pretrained``
            and ``AutoModel.from_pretrained``.
        csv_path: Path to a CSV file that has a ``"descriptions"`` column.
        device: Device to run on. Defaults to ``"cuda"`` when CUDA is
            available, otherwise ``"cpu"``.
        batch_size: Number of descriptions per forward pass.
        out_path: Where to save the tensor. Defaults to the module-level
            ``out_pt_file``.

    Returns:
        The path the tensor was saved to. Row ``i`` of the saved tensor
        corresponds to element ``i`` of ``np.unique`` over the column.

    Raises:
        ValueError: If the ``"descriptions"`` column has no entries.
    """
    if device is None:
        # Fall back to CPU instead of failing on machines without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if out_path is None:
        out_path = out_pt_file

    print("\nLoading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(lm_model)
    if lm_model == "mistralai/Mistral-7B-v0.1":
        # Reuse EOS as the padding token so that padding="longest" can be used.
        tokenizer.pad_token = tokenizer.eos_token

    print("\nLoading model")
    decoder = AutoModel.from_pretrained(lm_model).to(device)
    model = LM(decoder, lm_model).to(device)
    model.eval()  # inference only

    print("Loading data")
    df = pd.read_csv(csv_path)
    descriptions = df["descriptions"].tolist()
    label_descriptions = np.unique(descriptions)
    if len(label_descriptions) == 0:
        raise ValueError(f"No descriptions found in {csv_path!r}")

    print("Running tokenizer")
    dataset = TextDataset(label_descriptions.tolist())
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        shuffle=False,  # keep rows aligned with label_descriptions
    )

    print("Running model")
    outputs = []
    with torch.no_grad():
        for batch in tqdm(loader):
            outputs.append(model(batch.to(device)))

    outputs = torch.cat(outputs, dim=0)
    assert outputs.size(dim=0) == len(label_descriptions)

    torch.save(outputs, out_path)

    return out_path