<a href="https://colab.research.google.com/github/jlopetegui98/Literary-Fine-Tuning-of-LLM/blob/main/Experiments/experiments_wilde_ft_mistral.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Experiments with the Mistral 7B fine tuned model with Oscar Wilde texts

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install -U simpletransformers
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q trl xformers wandb datasets einops gradio sentencepiece

Collecting simpletransformers
  Downloading simpletransformers-0.64.5-py3-none-any.whl (250 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m250.7/250.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Collecting datasets (from simpletransformers)
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
Collecting seqeval (from simpletransformers)
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting tensorboardx (from simpletransformers)
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
Collecting wandb>=0.10.32 (from simpletransformers)
  

In [3]:
import torch
import simpletransformers
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging, TextStreamer
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os, wandb, platform, gradio, warnings
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from trl import SFTTrainer
from huggingface_hub import notebook_login
import json
from tqdm import tqdm

In [4]:
# data and models paths
dir_root = './drive/MyDrive/DL-ENS'
dir_data = f'{dir_root}/dataset'
clf_path = f'{dir_root}/models/BertClassifier(BERTAA)_balanced_data.pt'
list_to_generate_path = f'{dir_data}/story_prompts.txt'
ft_model = f'{dir_root}/models/Mistral7B_fine_tuned_OscarWilde.pt'

In [5]:
# load classifier (wilde vs kipling)
clf = torch.load(clf_path)
clf

<simpletransformers.classification.classification_model.ClassificationModel at 0x7fd47843f910>

In [6]:
#base model
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

In [24]:
# Load base model(Mistral 7B)
bnb_config = BitsAndBytesConfig(
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.bfloat16,
    bnb_4bit_use_double_quant= False,
)
model = AutoModelForCausalLM.from_pretrained(
   model_name,
    quantization_config=bnb_config,
    device_map={"": 0}
)
#Adding the adapters in the layers
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj"]
    )
model = get_peft_model(model, peft_config)

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [25]:
# Load the model
model.load_state_dict(torch.load(ft_model))

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token = True
tokenizer.add_bos_token, tokenizer.add_eos_token

(True, True)

In [26]:
type(model)

peft.peft_model.PeftModelForCausalLM

In [8]:
# function to tokenize the input in the expected form of the prompt
def tokenize(tokenizer, text):
  return tokenizer(f"<s>[INST]This are the first lines of a work of fiction. Continue it. {text} [/INST]", return_tensors = "pt", add_special_tokens = False)

In [29]:
# main function for experiments
def clf_exp(model, tokenizer, clf, texts):
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  generated_texts = []
  label_predictions = []
  for input in tqdm(texts):
    tokens = tokenize(tokenizer, input)
    model_inputs = tokens.to(device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=500, do_sample=True)
    decoded = tokenizer.batch_decode(generated_ids)
    preds, _ = clf.predict(decoded)
    label_predictions.extend(preds)
    generated_texts.extend(decoded)
    del model_inputs
    del decoded
    del generated_ids
  return label_predictions, generated_texts

In [10]:
texts = []
with open(list_to_generate_path, 'r+', encoding='utf-8') as fd:
  texts = fd.readlines()
texts = [text[:-1] for text in texts]

In [18]:
type(model)

collections.OrderedDict

In [30]:
# predict author for each input prompt
author_preds, generated_texts = clf_exp(model, tokenizer, clf, texts)

  0%|          | 0/100 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 1/100 [00:53<1:28:05, 53.39s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  2%|▏         | 2/100 [01:44<1:25:25, 52.30s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  3%|▎         | 3/100 [02:36<1:24:01, 51.97s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  4%|▍         | 4/100 [03:28<1:23:04, 51.92s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  5%|▌         | 5/100 [04:20<1:22:28, 52.08s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  6%|▌         | 6/100 [05:12<1:21:23, 51.95s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  7%|▋         | 7/100 [06:03<1:20:18, 51.81s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  8%|▊         | 8/100 [06:55<1:19:22, 51.77s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  9%|▉         | 9/100 [07:47<1:18:32, 51.78s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 10%|█         | 10/100 [08:40<1:18:13, 52.15s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 11%|█         | 11/100 [09:32<1:17:31, 52.26s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 12%|█▏        | 12/100 [10:30<1:18:53, 53.79s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 13%|█▎        | 13/100 [11:21<1:17:05, 53.17s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 14%|█▍        | 14/100 [12:14<1:15:47, 52.88s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 15%|█▌        | 15/100 [13:06<1:14:39, 52.70s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 16%|█▌        | 16/100 [13:59<1:13:45, 52.69s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 17%|█▋        | 17/100 [14:50<1:12:31, 52.43s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 18%|█▊        | 18/100 [15:43<1:11:35, 52.39s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 19%|█▉        | 19/100 [16:35<1:10:37, 52.31s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 20%|██        | 20/100 [17:27<1:09:46, 52.33s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 21%|██        | 21/100 [18:19<1:08:50, 52.29s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 22%|██▏       | 22/100 [19:12<1:08:09, 52.43s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 23%|██▎       | 23/100 [20:04<1:07:13, 52.38s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 24%|██▍       | 24/100 [20:57<1:06:17, 52.33s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 25%|██▌       | 25/100 [21:49<1:05:26, 52.36s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 26%|██▌       | 26/100 [22:41<1:04:32, 52.33s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 27%|██▋       | 27/100 [23:35<1:04:07, 52.70s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 28%|██▊       | 28/100 [24:27<1:02:58, 52.48s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 29%|██▉       | 29/100 [25:19<1:01:49, 52.24s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 30%|███       | 30/100 [26:11<1:01:04, 52.35s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 31%|███       | 31/100 [27:04<1:00:19, 52.45s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 32%|███▏      | 32/100 [27:56<59:30, 52.51s/it]  Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 33%|███▎      | 33/100 [28:49<58:43, 52.60s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 34%|███▍      | 34/100 [29:42<57:43, 52.48s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 35%|███▌      | 35/100 [30:34<56:42, 52.34s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 36%|███▌      | 36/100 [31:26<55:50, 52.35s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 37%|███▋      | 37/100 [32:18<54:50, 52.23s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 38%|███▊      | 38/100 [33:10<53:56, 52.20s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 39%|███▉      | 39/100 [34:02<53:01, 52.15s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 40%|████      | 40/100 [34:55<52:21, 52.36s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 41%|████      | 41/100 [35:47<51:31, 52.41s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 42%|████▏     | 42/100 [36:40<50:35, 52.33s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 43%|████▎     | 43/100 [37:32<49:38, 52.25s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 44%|████▍     | 44/100 [38:24<48:50, 52.33s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 45%|████▌     | 45/100 [39:17<48:04, 52.45s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 46%|████▌     | 46/100 [40:10<47:17, 52.55s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 47%|████▋     | 47/100 [41:02<46:20, 52.46s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 48%|████▊     | 48/100 [41:54<45:25, 52.41s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 49%|████▉     | 49/100 [42:47<44:33, 52.43s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 50%|█████     | 50/100 [43:39<43:37, 52.36s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 51%|█████     | 51/100 [44:31<42:44, 52.34s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 52%|█████▏    | 52/100 [45:23<41:49, 52.29s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 53%|█████▎    | 53/100 [46:15<40:52, 52.19s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 54%|█████▍    | 54/100 [47:07<39:59, 52.16s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 55%|█████▌    | 55/100 [48:00<39:11, 52.25s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 56%|█████▌    | 56/100 [48:52<38:20, 52.27s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 57%|█████▋    | 57/100 [49:44<37:27, 52.27s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 58%|█████▊    | 58/100 [50:37<36:34, 52.24s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 59%|█████▉    | 59/100 [51:28<35:37, 52.13s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 60%|██████    | 60/100 [52:20<34:43, 52.08s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 61%|██████    | 61/100 [53:12<33:50, 52.06s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 62%|██████▏   | 62/100 [54:04<32:56, 52.02s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 63%|██████▎   | 63/100 [54:57<32:08, 52.13s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 64%|██████▍   | 64/100 [55:49<31:22, 52.29s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 65%|██████▌   | 65/100 [56:41<30:24, 52.13s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 66%|██████▌   | 66/100 [57:34<29:38, 52.29s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 67%|██████▋   | 67/100 [58:26<28:42, 52.20s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 68%|██████▊   | 68/100 [59:18<27:45, 52.06s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 69%|██████▉   | 69/100 [1:00:10<26:53, 52.03s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 70%|███████   | 70/100 [1:01:01<25:59, 51.99s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 71%|███████   | 71/100 [1:01:53<25:07, 51.98s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 72%|███████▏  | 72/100 [1:02:46<24:20, 52.18s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 73%|███████▎  | 73/100 [1:03:38<23:28, 52.15s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 74%|███████▍  | 74/100 [1:04:30<22:37, 52.22s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 75%|███████▌  | 75/100 [1:05:23<21:47, 52.31s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 76%|███████▌  | 76/100 [1:06:16<20:59, 52.47s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 77%|███████▋  | 77/100 [1:07:10<20:15, 52.84s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 78%|███████▊  | 78/100 [1:08:02<19:19, 52.72s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 79%|███████▉  | 79/100 [1:08:54<18:23, 52.54s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 80%|████████  | 80/100 [1:09:47<17:31, 52.58s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 81%|████████  | 81/100 [1:10:39<16:37, 52.51s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 82%|████████▏ | 82/100 [1:11:32<15:48, 52.67s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 83%|████████▎ | 83/100 [1:12:25<14:57, 52.82s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 84%|████████▍ | 84/100 [1:13:18<14:02, 52.68s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 85%|████████▌ | 85/100 [1:14:10<13:08, 52.59s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 86%|████████▌ | 86/100 [1:15:03<12:16, 52.59s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 87%|████████▋ | 87/100 [1:15:56<11:25, 52.72s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 88%|████████▊ | 88/100 [1:16:48<10:32, 52.74s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 89%|████████▉ | 89/100 [1:17:41<09:39, 52.69s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 90%|█████████ | 90/100 [1:18:34<08:47, 52.73s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 91%|█████████ | 91/100 [1:19:27<07:54, 52.74s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 92%|█████████▏| 92/100 [1:20:19<07:01, 52.69s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 93%|█████████▎| 93/100 [1:21:13<06:10, 52.91s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 94%|█████████▍| 94/100 [1:22:06<05:17, 52.92s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 95%|█████████▌| 95/100 [1:22:58<04:24, 52.84s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 96%|█████████▌| 96/100 [1:23:51<03:31, 52.94s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 97%|█████████▋| 97/100 [1:24:44<02:38, 52.90s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 98%|█████████▊| 98/100 [1:25:37<01:45, 52.92s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 99%|█████████▉| 99/100 [1:26:30<00:52, 52.87s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 100/100 [1:27:23<00:00, 52.44s/it]


In [31]:
sum(author_preds)

8

In [34]:
# function to save the results
def save_generated_texts_and_labels(texts, labels, model = 'baseline'):
  dict_text_to_author = {'text': [], 'label': []}

  for i in range(len(texts)):
    dict_text_to_author['text'].append(texts[i])
    dict_text_to_author['label'].append(str(labels[i]))

  with open(dir_data + f"/{model}_generated_texts.json", 'w+') as fd:
    json.dump(dict_text_to_author, fd)

In [35]:
save_generated_texts_and_labels(generated_texts, author_preds,model = 'ft_mistral')