In [1]:
import torch
import json
import numpy as np
import random
from embed_llm.models.augmented_model import EmbedAugPipeline
from embed_llm.generation.evaluation import word_overlap, get_bleu_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed = 29
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


In [3]:
llm_path = '/lustre/scwpod02/client/kyutai-interns/hippop/models/mistral_7B'
#Must have a params json for pipeline

# Finished runs:
run_name = '128_SL_FN_False_0_MLP_True_CA_16_CAL_False_SKV_True_DB'
# run_name = '128_SL_FN_False_0_MLP_False_CA_False_DB'
# run_name = '128_SL_FN_False_0_MLP_True_CA_24_CAL_False_SKV_True_DB'

# run_name = '128_SL_FN_False_0_MLP_True_CA_16_CAL_False_SKV_False_DB'
# run_name = '128_SL_FN_False_0_MLP_False_CA_False_DB'



ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/' + run_name 

with open(f'{ckpt_path}'+ '/checkpoints/checkpoint_010000/params.json') as f:
    params = json.load(f)
print(params)

model_name = 'Mistral7B' 
device = 'cuda' if torch.cuda.is_available() else 'cpu'
w_embeds = True
max_batch_size = 4

{'w_embeds': True, 'norm_wo_embeds': False, 'mlp_project': {'hidden_dim': 4096, 'n_layers': 0, 'act': 'gelu', 'in_dim': 4096, 'out_dim': 4096}, 'training': True, 'param_dtype': 'float32', 'embedder_name': 'NVEmbed', 'trainable_embedder': False, 'causal': True, 'do_pool': True, 'n_truncated_layers': 4, 'normalize_embeddings': True, 'pooling_module': {'type': 'eos', 'r': 512, 'n_heads': 8, 'n_layers': 1}, 'continuation': False, 'shared_kv': False, 'cross_att': True, 'cross_att_layers': 16, 'do_both': True}


In [4]:
pipeline: EmbedAugPipeline = EmbedAugPipeline.load_inference_model(llm_path = llm_path, 
                                                                   ckpt_path = ckpt_path + '/checkpoints/checkpoint_010000', 
                                                                   device = device,
                                                                   llm_name = model_name, 
                                                                   embed_model_name = 'NVEmbed', # Not used if pretrainde ckpt available
                                                                    max_batch_size = max_batch_size)

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00,  1.07s/it]


Loading cross att state dict
Not only LoRA weights found in the checkpoint. Skipping other weights.
Loading and merging LoRA weights...


In [5]:
n_passages = 20

lim_toks = 128
eval_data = '/lustre/scwpod02/client/kyutai-interns/datasets/modular_finetuning/enwiki-20220120_valid.jsonl'
train_data = '/lustre/scwpod02/client/kyutai-interns/datasets/modular_finetuning/enwiki-20220120_train.jsonl'
train_passage = []
valid_passage = []

with open(train_data, 'r') as f:
    for i, line in enumerate(f):
        if i == n_passages:
            break
        train_passage.append(pipeline.tokenizer.decode(pipeline.tokenizer.encode(json.loads(line)['text'].split('\n\n')[1], eos = True, bos = True)[:lim_toks]))
        
with open(eval_data, 'r') as f:
    for i, line in enumerate(f):
        if i == n_passages:
            break
        valid_passage.append(pipeline.tokenizer.decode(pipeline.tokenizer.encode(json.loads(line)['text'].split('\n\n')[1], eos = True, bos = True)[:lim_toks]))
        

In [6]:
# Flipping attempts
w_embeds = True
temp = 0.7
max_tokens = 128
i_token_to_flip = -1

prompt = ''

text_conditioning ='Mario Bortolazzi (born 10 January 1965, in Verona) is an Italian professional football coach and a former player, who played as a midfielder. \
    \n\nHe played 12 seasons (241 games, 14 goals) in the Serie A for ACF Fiorentina, A.C. Milan, Hellas Verona F.C., Atalanta B.C. and Genoa C.F.C.'
        # \n\nIn his coaching career he has so far has always been an assistant to his former Milan teammate Roberto Donadoni.\
        #     \n\nHonours\n\n - Milan\n - Serie A champion: 1987–88.\n\n - Genoa\n - Anglo-Italian Cup winner: 1995–96.'


if w_embeds:
    pipeline.pipeline_args.w_embeds = True
else:
    pipeline.pipeline_args.w_embeds = False
    
if i_token_to_flip >= 0:
    temp = [temp] * max_tokens
    temp[i_token_to_flip] = 100
            
generated_sequence, logprobs = pipeline.generate(prompts = prompt, 
                                    text_conditioning = text_conditioning, 
                                    temperature = temp, 
                                    max_tokens = max_tokens,
                                    truncate_double_space = False,
                                    random_flip = i_token_to_flip)
print(generated_sequence)

  'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),


['map_bortoli.svg\n\nMario Bortoli (born 1951) is an Italian former professional football player. He played 12 seasons (236 games, 74 goals) in the Serie A for Genoa, Atalanta, Lazio, Fiorentina and Verona. He was part of the team that won the 1984–85 Coppa Italia with Verona. He was midfielder and coach of the team.']


In [7]:
temperatures = [0, 0.5, 0.7, 1, 1.5]
max_tokens = 150

results_generation = {'0':{'train': {'word_prompt':{}, 'empty_prompt':{}}, 'valid': {'word_prompt':{}, 'empty_prompt':{}}}, 
                        '0.5':{'train': {'word_prompt':{}, 'empty_prompt':{}}, 'valid': {'word_prompt':{}, 'empty_prompt':{}}},
                        '0.7':{'train': {'word_prompt':{}, 'empty_prompt':{}}, 'valid': {'word_prompt':{}, 'empty_prompt':{}}},
                        '1':{'train': {'word_prompt':{}, 'empty_prompt':{}}, 'valid': {'word_prompt':{}, 'empty_prompt':{}}},
                        '1.5':{'train': {'word_prompt':{}, 'empty_prompt':{}}, 'valid': {'word_prompt':{}, 'empty_prompt':{}}}}


n_passages = len(train_passage)
assert n_passages == len(valid_passage)

for temp in temperatures:
    print(f'Temperature: {temp}')    
    generated_sequences = []
    
    for i in range(0, n_passages, max_batch_size):
        passage = train_passage[i:i+max_batch_size]
        generated_sequence, logprobs = pipeline.generate(prompts = [text.split(' ')[0] for text in passage], 
                                    text_conditioning = passage, 
                                    temperature = temp, 
                                    max_tokens = max_tokens,
                                    truncate_double_space = False)
           
        generated_sequences.extend(generated_sequence)
    results_generation[str(temp)]['train']['word_prompt'] = {'seq':generated_sequences}
    print('Train Passage:', passage)
    print('Train Generated:', generated_sequence)
    generated_sequences = []
    for i in range(0, n_passages, max_batch_size):
        passage = train_passage[i:i+max_batch_size]
        generated_sequence, logprobs = pipeline.generate(prompts = [''] * len(passage), 
                                    text_conditioning = passage, 
                                    temperature = temp, 
                                    max_tokens = max_tokens,
                                    truncate_double_space = False)
           
        generated_sequences.extend(generated_sequence)    
    results_generation[str(temp)]['train']['empty_prompt'] = {'seq':generated_sequences}
    

    generated_sequences = []
    for i in range(0, n_passages, max_batch_size):
        passage = valid_passage[i:i+max_batch_size]
        generated_sequence, logprobs = pipeline.generate(prompts = [text.split(' ')[0] for text in passage], 
                                    text_conditioning = passage, 
                                    temperature = temp, 
                                    max_tokens = max_tokens,
                                    truncate_double_space = False)
           
        generated_sequences.extend(generated_sequence)    
    results_generation[str(temp)]['valid']['word_prompt'] = {'seq':generated_sequences}
    
    generated_sequences = []
    for i in range(0, n_passages, max_batch_size):
        passage = valid_passage[i:i+max_batch_size]
        generated_sequence, logprobs = pipeline.generate(prompts = [''] * len(passage), 
                                    text_conditioning = passage, 
                                    temperature = temp, 
                                    max_tokens = max_tokens,
                                    truncate_double_space = False)
           
        generated_sequences.extend(generated_sequence)    
    results_generation[str(temp)]['valid']['empty_prompt'] = {'seq':generated_sequences}
    print('Valid Passage:', passage)
    print('Valid Generated:', generated_sequence)
        

Temperature: 0
Train Passage: ['Salvador María de Iturbide y Huarte (17 July 1820 – 7 June 1856) was the eighth child (and third son) of Agustín I of Mexico and Empress Ana Maria Huarte. He was married in 1845 to Doña María del Rosario de Marzán y Guisasola. His descendants, through his son Salvador de Iturbide y de Marzán, are the current pretenders to the Mexican Throne. He was in the Secretary Mexican Legation in Washington, D.C. in', 'Parnaíba (U-17) is a river monitor of the Brazilian Navy. She is currently the last monitor in service.', "St. Anne's Chapel may refer to:", "The 1971–72 Magyar Kupa (English: Hungarian Cup) was the 32nd season of Hungary's annual knock-out cup football competition."]
Train Generated: ['de Hita\n\nSalvador de Hita (1808 – 25 September 1858) was the Mexican Agente-General in charge of the Fourth Marquisate of the Legion of the Mexican Emigrants. He was the son of Juan de Dios de Hita and Rosario de Arroyo. He married Ana Josefa Carrillo y Pardo. He was

In [8]:
metrics = []
for temp in results_generation.keys():
    for split in results_generation[temp].keys():
        for prompt_type in results_generation[temp][split].keys():
            generated_sequences = results_generation[temp][split][prompt_type]['seq']
            if prompt_type == 'empty_prompt':
                gt_passage = train_passage if split == 'train' else valid_passage
                overlap = word_overlap(gt_passage, generated_sequences)
                bleu_score = get_bleu_score(gt_passage, generated_sequences)
            elif prompt_type == 'word_prompt':
                gt_passage = train_passage if split == 'train' else valid_passage
                gt_passage = [' '.join(text.split(' ')[1:]) for text in gt_passage]
                overlap = word_overlap(gt_passage, generated_sequences)
                bleu_score = get_bleu_score(gt_passage, generated_sequences)
   
            print(f'Temperature: {temp}, Split: {split}, Prompt Type: {prompt_type}, Overlap: {overlap}', 'Bleu Score:', bleu_score)
            metrics.append({'temp': temp, 'split': split, 'prompt_type': prompt_type, 'overlap': overlap, 'bleu_score': bleu_score})
            
with open(f'{ckpt_path}/results_generation.json', 'w') as f:
    json.dump(metrics, f)

Temperature: 0, Split: train, Prompt Type: word_prompt, Overlap: 0.33384146341463417 Bleu Score: 0.05465922559932734
Temperature: 0, Split: train, Prompt Type: empty_prompt, Overlap: 0.19431988041853512 Bleu Score: 0.01855941738262422
Temperature: 0, Split: valid, Prompt Type: word_prompt, Overlap: 0.37920937042459735 Bleu Score: 0.062006637948884266
Temperature: 0, Split: valid, Prompt Type: empty_prompt, Overlap: 0.11 Bleu Score: 0.0
Temperature: 0.5, Split: train, Prompt Type: word_prompt, Overlap: 0.3780487804878049 Bleu Score: 0.06093286378213287
Temperature: 0.5, Split: train, Prompt Type: empty_prompt, Overlap: 0.28400597907324365 Bleu Score: 0.040675509530223725
Temperature: 0.5, Split: valid, Prompt Type: word_prompt, Overlap: 0.3645680819912152 Bleu Score: 0.05360364533788957
Temperature: 0.5, Split: valid, Prompt Type: empty_prompt, Overlap: 0.31142857142857144 Bleu Score: 0.03728283167774254
Temperature: 0.7, Split: train, Prompt Type: word_prompt, Overlap: 0.36737804878048

## Old

In [None]:
llm_path = '/lustre/scwpod02/client/kyutai-interns/hippop/models/mistral_7B'
#Must have a params json for pipeline

# No embeddings:
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/no_embed_bs16_lr5e-5Mistral7B88d0b42410aa4ec12025/checkpoints/checkpoint_002500'

# Length tokens:
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/SL_512t_Mistral7B20ed0018b2a84fba09c4/checkpoints/checkpoint_005000'
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/SL_256t_Mistral7Be9ffc00fa42bedbc50d0/checkpoints/checkpoint_010000'
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/SL_128t_Mistral7B226729d875c65b331ef8/checkpoints/checkpoint_010000'
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/SL_64t_Mistral7B9bbea1b3b8dc23079b04/checkpoints/checkpoint_010000'
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/SL_32t_Mistral7Bccbc3f29d69bd124c6cf/checkpoints/checkpoint_010000'
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/SL_16t_Mistral7B7bc7dcc2ba28873eda96/checkpoints/checkpoint_010000'
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/mean_not_causal/checkpoints/checkpoint_007500'


# # Continuation:
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/continuation_Mistral7B20ed0018b2a84fba09c4/checkpoints/checkpoint_006000'
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/mean_finetuned_notcausal_continuationMistral7B20ed0018b2a84fba09c4/checkpoints/checkpoint_005500'

# # Cross-Attention:
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/cross_att_5_last_layersMistral7Bdbbb7faebb2f32cf20e9/checkpoints/checkpoint_010000'
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/cross_att_fine_tuned_embedder_5_last_layersMistral7Bdbbb7faebb2f32cf20e9/checkpoints/checkpoint_007500'
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/cross_att_finetuned_notcausal_continuationMistral7B20ed0018b2a84fba09c4/checkpoints/checkpoint_005000'
# ckpt_path = '/lustre/scwpod02/client/kyutai-interns/hippop/tmp/cross_att_pretrained_continuationMistral7B20ed0018b2a84fba09c4/checkpoints/checkpoint_008500'

with open(f'{ckpt_path}/params.json') as f:
    params = json.load(f)
print(params)

model_name = 'Mistral7B' # Mistral7B, Llama3.2-3B, Gemma7B
device = 'cuda' if torch.cuda.is_available() else 'cpu'
w_embeds = True
max_batch_size = 4

# variant = '7b' if model_name == 'Gemma7B' else None

### Modify old params

In [9]:
with open(ckpt_path + '/params.json') as f:
    params = json.load(f)
print(params)
# if 'do_pool'  not in params.keys():
if 'n_truncated_layers' in params['pooling_module'].keys():
    params['n_truncated_layers'] = params['pooling_module']['n_truncated_layers']
    del params['pooling_module']['n_truncated_layers']
    

if params['cross_att'] is not None:
    print('here')
    params['normalize_embeddings'] = True if params['cross_att'] else False
    if params['start_cross_att'] is None:
        del params['start_cross_att']
    else:
        params['cross_att_layers'] = 32 - params["start_cross_att"]
        del params['start_cross_att']
    params['do_pool'] = False if params['cross_att'] else True
else:
    params['do_pool'] = True
print(params)
with open(ckpt_path + '/params.json', 'w') as f:
    json.dump(params, f)

        


{'w_embeds': True, 'norm_wo_embeds': False, 'mlp_project': {'hidden_dim': 4096, 'n_layers': 3, 'act': 'gelu', 'in_dim': 4096, 'out_dim': 4096}, 'training': True, 'param_dtype': 'float32', 'trainable_embedder': True, 'causal': False, 'pooling_module': {'type': 'mean', 'r': 512, 'n_heads': 8, 'n_layers': 1}, 'continuation': True, 'cross_att': True, 'do_pool': False, 'n_truncated_layers': 4, 'normalize_embeddings': True, 'cross_att_layers': 5}
here


KeyError: 'start_cross_att'

### Reconstruction

In [9]:
for param in tests:
    print('Param:', param)
    if param['w_embeds']:
        pipeline.pipeline_args.w_embeds = True
    else:
        pipeline.pipeline_args.w_embeds = False
    
    final_valid_prompts = [passage.split(' ')[0] for passage in valid_passage][2] 
    text_valid_conditioning = [passage[:100] for passage in valid_passage][2]
    print('Prompt', final_valid_prompts, ' | Passage', text_valid_conditioning)
    generated_sequence = pipeline.generate(prompts = final_valid_prompts, 
                                        text_conditioning = text_valid_conditioning, 
                                        temperature = param['temperature'], 
                                        max_tokens = max_tokens,
                                        truncate_double_space = False)
    print('Valid  word', generated_sequence)
    
    final_valid_prompts = ['' for passage in train_passage][1]
    text_valid_conditioning = [passage[:100] for passage in valid_passage][2]
    generated_sequence = pipeline.generate(prompts = final_valid_prompts, 
                                        text_conditioning = text_valid_conditioning, 
                                        temperature = param['temperature'], 
                                        max_tokens = max_tokens,
                                        truncate_double_space = False)
    print('Valid  empty', generated_sequence)

    final_train_prompts =  [passage.split(' ')[0] for passage in train_passage][1] 
    text_train_conditioning = [passage[:100] for passage in train_passage][1]
    print('Prompt', final_train_prompts, ' | Passage', text_train_conditioning)
    generated_sequence = pipeline.generate(prompts = final_train_prompts, 
                                       text_conditioning = text_train_conditioning, 
                                       temperature = param['temperature'], 
                                       max_tokens = max_tokens,
                                       truncate_double_space = False)
    print('Train word', generated_sequence)
    final_train_prompts = ['' for passage in train_passage][1] 
    text_train_conditioning = [passage[:100] for passage in train_passage][1]
    generated_sequence = pipeline.generate(prompts = final_train_prompts, 
                                       text_conditioning = text_train_conditioning, 
                                       temperature = param['temperature'], 
                                       max_tokens = max_tokens,
                                       truncate_double_space = False)
    print('Train empty', generated_sequence)

    

Param: {'w_embeds': True, 'temperature': 0}
Prompt The  | Passage The Roman Republic (Repubblica Romana) was a sister republic of the First French Republic. It was pr
Valid  word ['Roman Republic (Repubblica) was a short-lian state, a republican state, was a short-republican state, was a republican state, was a republican state, was a republican state, was a republican republic, was a republican republic, was a republican republic, was a republican republic, was a republican republic, was a republican republic, was a republican republic, was a republican republic, was a republican republic, republic, republic, republic, republic, republic republic, republic, republic republic, republic republic, republic republic, republic republic, republic republic, republic, republic republic republic, republic, republic republic, republic, republic, republic republic, republic, republic, republic, republic, republic, republic, republic, republic, republic, republic, republic, republic, republic, re

KeyboardInterrupt: 

In [None]:
# 1 information in the doc which enables to answer the question but not good response often in-context
# 2 information in the doc which enables to answer the question and good response often in-context
# 3 Hard negative passage
# 4 Same

# prompt_prefix = "Query: who wrote the song photograph by ringo starr\nAnswer: Ringo Starr\n\nQuery: who is playing the halftime show at super bowl 2016\nAnswer: Coldplay\n\nQuery: where was the world economic forum held this year\nAnswer: Davos\n\nQuery: where are the giant redwoods located in california\nAnswer: Humboldt County\n\nQuery: who has made the most premier league appearances\nAnswer: Gareth Barry\n\nQuery: "
# prompts = ['who has most followers on instagram in world','who did the united states win its independence from', 'locations for the film an englishman who went up a hill', 'who is the valley of the dolls based on']
# final_prompts = [prompt_prefix + prompt + '\nAnswer:' for prompt in prompts]

# text_conditioning = ["This list contains the top 50 accounts with the most followers on the photo and video-sharing social platform Instagram. As of July 2019, the most followed user is Instagram's own account, with over 308 million followers. Cristiano Ronaldo is the most followed individual, with over 177 million followers. Fifteen accounts have exceeded 100 million followers on the site.",
#                      "During the American Revolution, the legal separation of the thirteen colonies from Great Britain in 1776 actually occurred on July 2, when the Second Continental Congress voted to approve a resolution of independence that had been proposed in June by Richard Henry Lee of Virginia declaring the United States independent from Great Britain's rule. After voting for independence, Congress turned its attention to the Declaration of Independence, a statement explaining this decision, which had been prepared by a Committee of Five, with Thomas Jefferson as its principal author. Congress debated and revised the wording of the Declaration, finally approving it two days later on July 4. A day earlier, John Adams had written to his wife Abigail",
#                      'The village was a primary location for the making of the film \"The Englishman Who Went Up a Hill But Came Down a Mountain\", which starred Hugh Grant. The hilltop scenes were filmed on the Gyrn, the long hill that overlooks the village. It was also featured in \"Monk\'s Hood\", an episode of \"The Cadfael Chronicles\"',
#                      'Valley of the Dolls is the first novel by American writer Jacqueline Susann. Published in 1966, the book was the biggest selling novel of its year. To date, it has sold more than 31 million copies, making it one of the best-selling works in publishing history.']

# answers = ['Instagram','Great Britain',"Llansilin in Powys",["Judy Garland", "Carole Landis", "Dean Martin", "Ethel Merman"]]

n_passages = 4
eval_data = '/lustre/scwpod02/client/kyutai-interns/datasets/modular_finetuning/enwiki-20220120_valid.jsonl'
train_data = '/lustre/scwpod02/client/kyutai-interns/datasets/modular_finetuning/enwiki-20220120_train.jsonl'
train_passage = []
valid_passage = []
with open(train_data, 'r') as f:
    for i, line in enumerate(f):
        if i == n_passages:
            break
        train_passage.append(json.loads(line)['text'].split('\n\n')[1])
        
with open(eval_data, 'r') as f:
    for i, line in enumerate(f):
        if i == n_passages:
            break
        valid_passage.append(json.loads(line)['text'].split('\n\n')[1])
        
tests = [{'w_embeds': True, 'temperature': 0 },  {'w_embeds': True, 'temperature': 0.7 }, {'w_embeds': False, 'temperature': 0.7 }]
# print('Train passage:', train_passage)
# print('Valid passage:', valid_passage)

In [None]:
conditioning = ['Kyutai is a non-profit laboratory dedicated to open research in AI, founded in November 2023 by the iliad Group, CMA CGM and Schmidt Sciences. Launched with an initial team of six leading scientists, who have all worked with Big Tech labs in the USA, Kyutai continues to recruit at the highest level, and also offers internships to research Master’s degree students.']*4
prompts = ['who are the founders of Kyutai?', 'when was Kyutai founded?', 'how many scientists were in the initial team?', 'what does Kyutai offer to research Master’s degree students?']
if w_embeds:
    pipeline.pipeline_args.w_embeds = True
else:
    pipeline.pipeline_args.w_embeds = False
generated_sequence = pipeline.generate(prompts = prompts,
                                      text_conditioning = conditioning,
                                      temperature = 0.5, 
                                      max_tokens =200,
                                      truncate_double_space = False)
# random_flip, put the number of the token to flip. 
print(generated_sequence)

if w_embeds:
    pipeline.pipeline_args.w_embeds = True
else:
    pipeline.pipeline_args.w_embeds = False
generated_sequence, logprobs = pipeline.generate(prompts = ['who has most followers on Instagram in world?'],
                                      text_conditioning = ["This list contains the top 50 accounts with the most followers on the photo and video-sharing social platform Instagram. As of July 2019, the most followed user is Instagram's own account, with over 308 million followers. Cristiano Ronaldo is the most followed individual, with over 177 million followers."],
                                      temperature = 0.4, 
                                      max_tokens =200,
                                      truncate_double_space = False)
print(generated_sequence)


### Continuation

In [7]:
# Continuation
for param in tests:
    print('Param:', param)
    if param['w_embeds']:
        pipeline.pipeline_args.w_embeds = True
    else:
        pipeline.pipeline_args.w_embeds = False
    
    final_valid_prompts = [passage[100:].split(' ')[0] for passage in valid_passage][2] 
    text_valid_conditioning = [passage[:100] for passage in valid_passage][2]
    print('Passage', text_valid_conditioning, ' | Truth', [passage[100:200] for passage in valid_passage][2] )
    generated_sequence = pipeline.generate(prompts = final_valid_prompts, 
                                        text_conditioning = text_valid_conditioning, 
                                        temperature = param['temperature'], 
                                        max_tokens = max_tokens,
                                        truncate_double_space = False)
    print('Valid  word', generated_sequence)
    
    final_valid_prompts = ['' for passage in train_passage][2]
    text_valid_conditioning = [passage[:100] for passage in valid_passage][2]
    generated_sequence = pipeline.generate(prompts = final_valid_prompts, 
                                        text_conditioning = text_valid_conditioning, 
                                        temperature = param['temperature'], 
                                        max_tokens = max_tokens,
                                        truncate_double_space = False)
    print('Valid  empty', generated_sequence)

    final_train_prompts =  [passage[100:].split(' ')[0] for passage in train_passage][1] 
    text_train_conditioning = [passage[:100] for passage in train_passage][1]
    print('Passage', text_train_conditioning, ' | Truth', [passage[100:200] for passage in train_passage][1] )
    generated_sequence = pipeline.generate(prompts = final_train_prompts, 
                                       text_conditioning = text_train_conditioning, 
                                       temperature = param['temperature'], 
                                       max_tokens = max_tokens,
                                       truncate_double_space = False)
    print('Train word', generated_sequence)
    final_train_prompts = ['' for passage in train_passage][1] 
    text_train_conditioning = [passage[:100] for passage in train_passage][1]
    generated_sequence = pipeline.generate(prompts = final_train_prompts, 
                                       text_conditioning = text_train_conditioning, 
                                       temperature = param['temperature'], 
                                       max_tokens = max_tokens,
                                       truncate_double_space = False)
    print('Train empty', generated_sequence)


Param: {'w_embeds': True, 'temperature': 0}
Passage The Roman Republic (Repubblica Romana) was a sister republic of the First French Republic. It was pr  | Truth oclaimed on 18 February 1798 after Louis-Alexandre Berthier, a general of Napoleon, had occupied the
Valid  word ['the first Roman Republic in 509 BC, and the first Roman Republic was established. The Roman Republic was a period of great expansion for Rome.']
Valid  empty ['# 1910 in Romanian literature\n\nThis article presents a list of the literary events and publications of Romania in 1910.']
Passage Cochamó is a Chilean town and commune located in Llanquihue Province, Los Lagos Region. The capital   | Truth of the commune is the town of Río Puelo, which is named after the Puelo River.
Train word ['the province of Chile is located in the Coquimbo Region, in the Limarí Province, in the Ovalle commune.']
Train empty ['# 1999 Chilean local elections\n\nThe **1999 Chilean local elections** were held on 19 December 1999.']
Param