In [None]:
# from transformers import T5TokenizerFast, T5EncoderModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import numpy as np
from pathlib import Path
import time
import pandas as pd

In [None]:
def naming(full_path):
    """Return the part of ``full_path`` that follows its last ``/``.

    A string without any ``/`` is returned unchanged, and a string ending in
    ``/`` yields the empty string.
    """
    _head, _sep, last_component = full_path.rpartition("/")
    return last_component

In [None]:
def _mean_pooled_encoding(tokenizer, T5, line, device):
    """Encode one text with the T5 encoder and average its last-layer token states.

    Returns a 1-D tensor: the mean over the token axis of the encoder's last
    hidden state (mean pooling, see Ni et al. 2021, "Sentence-T5").
    """
    encoded = tokenizer(line, return_tensors="pt", truncation=True, max_length=512).to(device)

    with torch.no_grad():
        output = T5.encoder(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            return_dict=True,
        )

    # Drop ONLY the batch axis. A bare squeeze() would also drop the token axis
    # of a one-token input, turning the mean below into a scalar.
    last_hidden = output.last_hidden_state.squeeze(0)
    return torch.mean(last_hidden, dim=0)


def text2t5(data_path, local_models_at, huggface_models = [], device = "cpu"):
    """Embed every replacement text with each given T5 checkpoint and save the vectors.

    For each directory ``dwe`` in ``data_path``, each ``meaning`` in
    ("ingroup", "outgroup") and each ``rnd`` in ("first_round", "second_round"),
    reads ``<data_path>/<dwe>/<meaning>/<rnd>/replacements.txt`` (tab-separated,
    first column used as index) and embeds the first remaining column line by
    line. For every model the result is written to
    ``<...>/<rnd>/vectors/<naming(model)>/vecs.txt``, one line per text:
    ``<index>\\t<space-separated vector values>``. Missing output directories
    are created; an existing ``vecs.txt`` is overwritten.

    Parameters
    ----------
    data_path : str or Path
        Root directory holding one sub-directory per ``dwe``.
    local_models_at : str, Path or None
        Directory whose entries are each loaded as a local checkpoint, or
        ``None`` to use only ``huggface_models``.
    huggface_models : list of str
        Model identifiers passed to ``from_pretrained``.
    device : str
        Torch device the model and inputs are moved to.

    Notes
    -----
    Output directories are keyed by ``naming(model)``, so two models with the
    same final path component write to the same ``vecs.txt``.
    """
    t0 = time.time()

    data_path = Path(data_path)

    # Copy rather than alias, so the caller's list / the shared default is never touched.
    models = list(huggface_models)
    if local_models_at is not None:
        models = [f"{local_models_at}/{model}" for model in os.listdir(local_models_at)] + models

    # Collect every (dwe, meaning, round) folder first so that each checkpoint
    # is loaded exactly once below instead of once per folder.
    round_dirs = []
    for dwe in os.listdir(data_path):
        if not (data_path / dwe).is_dir():
            # Stray files in the data root are not dwe folders.
            continue
        for meaning in ["ingroup", "outgroup"]:
            for rnd in ["first_round", "second_round"]:
                round_dir = data_path / dwe / meaning / rnd
                os.makedirs(round_dir / "vectors", exist_ok=True)
                round_dirs.append((dwe, meaning, rnd, round_dir))

    for model in models:
        # Load the checkpoint that is actually being iterated over.
        tokenizer = AutoTokenizer.from_pretrained(model, model_max_length=512)
        T5 = AutoModelForSeq2SeqLM.from_pretrained(model)
        T5.to(device)

        for dwe, meaning, rnd, round_dir in round_dirs:
            t = time.time()
            replacements = pd.read_csv(round_dir / "replacements.txt", sep = "\t", index_col = 0)

            path = round_dir / "vectors" / naming(model)
            os.makedirs(path, exist_ok=True)
            print()
            print(f"{dwe:<15}{meaning:<10}{rnd:<15}{naming(model)}")

            total = len(replacements)
            vectors = []
            # Count rows positionally: cheap, and independent of index labels.
            for done, (idx, line) in enumerate(zip(replacements.index, replacements.iloc[:,0]), start=1):
                pcent = round((done / total) * 100, 1)
                print(f"{pcent:<10}{int((time.time()-t))} s.", end="\r")

                vector = _mean_pooled_encoding(tokenizer, T5, line, device)
                as_str = " ".join(str(value) for value in vector.tolist())
                vectors.append(f"{idx}\t{as_str}\n")

            with open(path / "vecs.txt", mode = "w") as f:
                f.writelines(vectors)

        # Release this checkpoint before the next one is loaded.
        del tokenizer, T5

    print()
    t = time.time()
    print("Done!", int((t-t0)/60), "m.", int((t-t0)%60), "s.")

## Run

In [None]:
# Embed all replacement texts under ../data/replacements/data/ on the GPU,
# using only the listed Hugging Face model (no local model directory).
text2t5(data_path=Path("../data/replacements/data/"), local_models_at=None,
        huggface_models=["google/mt5-xl"], device="cuda")