In [None]:
# Show the GPUs and their current memory use before choosing DEVICE_NUM below.
!nvidia-smi

In [None]:
import os
# Physical index of the GPU to use. Once CUDA_VISIBLE_DEVICES is set, only this GPU is
# exposed to the process and CUDA renumbers it as device 0 (i.e. it is "cuda:0").
# This must run before torch initialises CUDA to take effect.
DEVICE_NUM = 0
os.environ["CUDA_VISIBLE_DEVICES"] = str(DEVICE_NUM)

In [None]:
import torch
from easyeditortest import BaseEditor
from easyeditortest import ROMEHyperParams, R_ROMEHyperParams, FTHyperParams, MEMITHyperParams
from easyeditortest.editors import seed_everything

# Seed every RNG up front so editing runs are reproducible; change SEED in one place only.
SEED = 42
torch.manual_seed(SEED)
seed_everything(SEED)



def get_vram(device=None):
    """Print current GPU memory usage as text plus a 24-cell usage bar.

    Args:
        device: CUDA device to query (index, string or ``torch.device``). ``None``
            (the default) queries the current device, as before.

    Prints ``'No GPU available'`` when CUDA is unavailable. Returns ``None``.
    """
    if not torch.cuda.is_available():
        print('No GPU available')
        return

    # A single query gives a consistent (free, total) snapshot, in bytes.
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    free = free_bytes / 1024 ** 3
    total = total_bytes / 1024 ** 3

    total_cubes = 24
    free_cubes = int(total_cubes * free / total)
    used_cubes = total_cubes - free_cubes
    print(f'VRAM: {total - free:.2f}/{total:.2f}GB\t VRAM:[' + used_cubes * '▮' + free_cubes * '▯' + ']')


In [None]:
# Memory usage before any model is loaded.
get_vram()

In [None]:
# Editing method to run: "MEMIT", "ROME", "R-ROME" or "FT".
# TODO: R-ROME run still pending.
baseline = "MEMIT"

# Model family and size(s); alternative tried: 'gpt-neo' with ['1.3B', '2.7B'].
model_type = 'gpt-j'
model_sizes = ['6B']
model_size = model_sizes[0]

# Decoding setting tag; it is part of the leaked-prompts CSV file name
# (presumably the decoding strategy used by the attack that produced it).
alg = "greedy"

model_name = "{}-{}".format(model_type, model_size)
model_name

In [None]:
import pandas as pd
import os
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM

# Load the hyper-parameters for the selected editing method from
# ./hparams_/<METHOD>/<model_name>.yaml.
if baseline == "ROME":
    hparams = ROMEHyperParams.from_hparams(f'./hparams_/ROME/{model_name}.yaml')
elif baseline == "R-ROME":
    hparams = R_ROMEHyperParams.from_hparams(f'./hparams_/R-ROME/{model_name}.yaml')
elif baseline.startswith("MEMIT"):
    hparams = MEMITHyperParams.from_hparams(f'./hparams_/MEMIT/{model_name}.yaml')
elif baseline == "FT":
    BATCH_SIZE = 8
    hparams = FTHyperParams.from_hparams(f'./hparams_/FT/{model_name}.yaml')
    hparams.batch_size = BATCH_SIZE
else:
    # Without this, a typo in `baseline` would either raise NameError below or
    # silently reuse a stale `hparams` left in the kernel by an earlier run.
    raise ValueError(
        f"Unknown baseline {baseline!r}; expected 'ROME', 'R-ROME', 'MEMIT*' or 'FT'"
    )

# CUDA_VISIBLE_DEVICES (set at the top of the notebook) exposes only the chosen GPU,
# which CUDA renumbers as device 0. Using f"cuda:{DEVICE_NUM}" would point at a
# non-existent device whenever DEVICE_NUM != 0.
hparams.device = "cuda:0" if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained(hparams.model_name)
# Reuse EOS as the padding token and pad on the left, the usual setup for batched
# generation with decoder-only models (assumes the tokenizer defines an EOS token).
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'

In [None]:
# Length of the leaked context used as the edit prompt (presumably in tokens). It selects
# the CSV column `context-{CONTEXT}` and the input file name. Values tried: 50, 100, 200;
# larger contexts are possible.
CONTEXT = 200

# Model-size tag used in the leaked-CSV file names.
bil = model_size

# NOTE: `alg` is defined once in the configuration cell above. Re-assigning it here
# would silently override any change made there.

In [None]:
# Uncomment to drop a previously edited model (only valid if `edited_model` exists),
# so its GPU memory can be reclaimed before editing again.
#del edited_model
# Return cached-but-unused blocks from PyTorch's CUDA allocator, then report usage.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
get_vram()

In [None]:
def get_target_weights(model, medit_hyperparams):
    """Collect the weight matrices of the modules an editing method rewrites.

    Args:
        model: a ``torch.nn.Module`` whose ``state_dict()`` has a
            ``'<module>.weight'`` entry for every edited layer.
        medit_hyperparams: hyper-parameters exposing ``rewrite_module_tmp`` (a format
            string with one ``{}`` placeholder for the layer index) and ``layers``
            (the layer indices to fill in).

    Returns:
        dict mapping each formatted module name to its ``.weight`` tensor. The tensors
        come from ``state_dict()``, so they are detached views that share storage with
        the model's parameters, not copies.

    Prints each collected weight's shape as a quick sanity check.
    """
    # Build the state dict once; state_dict() rebuilds the full mapping on every call.
    state = model.state_dict()
    weights = dict()
    for layer in medit_hyperparams.layers:
        layer_name = medit_hyperparams.rewrite_module_tmp.format(layer)
        weights[layer_name] = state[f'{layer_name}.weight']
        print(weights[layer_name].shape)
    return weights

def get_original_weights(medit_hyperparams, device="cuda"):
    """Load a fresh, unedited copy of the model and return its target-layer weights.

    Args:
        medit_hyperparams: hyper-parameters providing ``model_name`` (passed to
            ``AutoModelForCausalLM.from_pretrained``) plus the fields read by
            ``get_target_weights``.
        device: where to place the model before extracting weights. The default
            ``"cuda"`` is equivalent to the previous ``.cuda()`` call (current CUDA device).

    Returns:
        dict of ``{module_name: weight tensor}`` as returned by ``get_target_weights``,
        on ``device``. The remainder of the model is dropped so its memory can be reclaimed.
    """
    import gc

    model = AutoModelForCausalLM.from_pretrained(medit_hyperparams.model_name).to(device)
    weights = get_target_weights(model=model, medit_hyperparams=medit_hyperparams)
    del model
    # Break any reference cycles so the model is actually freed before asking the
    # CUDA caching allocator to release its unused blocks.
    gc.collect()
    torch.cuda.empty_cache()
    return weights

In [None]:
# Identifier of the current configuration (model type, size, context length).
f"{model_type}-{bil}-{CONTEXT}"

In [None]:
# Leaked-email prompts for this model / context length / decoding setting
# (presumably produced by the extraction attack in ../Attacks-PME).
leaked_csv = "../Attacks-PME/leaked/{}-{}-{}-{}.csv".format(model_type, bil, CONTEXT, alg)
prompts = pd.read_csv(leaked_csv)
prompts.head()

In [None]:
import spacy
import tqdm
from nltk.tag import pos_tag



# When True, write the computed 'subject' column back into the leaked-prompts CSV
# so later runs can skip the NER pass.
TO_SAVE_SUBJECTS = False

# spaCy English pipeline, used for named-entity recognition on the contexts.
nlp = spacy.load("en_core_web_sm")


def _extract_subject(context, name):
    """Choose the editing subject for one leaked-email prompt.

    Priority order:
      1. ``name`` if it occurs verbatim in ``context``;
      2. otherwise the ORG/PERSON entity closest to the end of ``context``;
      3. otherwise the last word NLTK tags as NOUN (universal tagset);
      4. otherwise the last whitespace token (or ``context`` itself if it has none),
         so every row still gets a subject instead of raising IndexError.

    Returns:
        ``(subject, source)`` with ``source`` in ``{'name', 'entity', 'noun', 'fallback'}``.
    """
    # A missing name (NaN float from the CSV) would make `in` raise TypeError, and the
    # empty string is a substring of everything, so both count as "not found".
    if isinstance(name, str) and name and name in context:
        return name, 'name'

    # Scan entities from the end so the one nearest the end of the context wins.
    for ent in reversed(nlp(context).ents):
        if ent.label_ in ('ORG', 'PERSON'):
            return ent.text, 'entity'

    tokens = context.split()
    nouns = [word for word, tag in pos_tag(tokens, tagset='universal') if tag == 'NOUN']
    if nouns:
        return nouns[-1], 'noun'
    return (tokens[-1] if tokens else context), 'fallback'


if 'subject' not in prompts.columns:
    print("Computing subject")
    context_col = f'context-{CONTEXT}'
    counts = {'name': 0, 'entity': 0, 'noun': 0, 'fallback': 0}
    subjects = []
    rows = zip(prompts[context_col], prompts['name'])
    for context, name in tqdm.tqdm(rows, total=len(prompts)):
        subj, source = _extract_subject(context, name)
        subjects.append(subj)
        counts[source] += 1

    tot = len(subjects)
    print(f'Names as subjects {counts["name"]}/{tot}')
    print(f'Entities as subjects {counts["entity"]}/{tot}')
    print(f'Nouns as subjects {counts["noun"]}/{tot}')
    if counts['fallback']:
        print(f'Last tokens (fallback) as subjects {counts["fallback"]}/{tot}')
    # Sanity check: every row must have received exactly one subject.
    print(len(prompts), sum(counts.values()))
    prompts['subject'] = subjects
    if TO_SAVE_SUBJECTS:
        prompts.to_csv(f"../Attacks-PME/leaked/{model_type}-{bil}-{CONTEXT}-{alg}.csv", index=None)

In [None]:
%%time

# Build the parallel edit-request lists. Each leaked context is the prompt, the true
# (leaked) email is the ground truth, and the model is edited to emit a placeholder
# address instead. .tolist() reads rows positionally, so a non-default index
# (e.g. after filtering) cannot misalign or KeyError the way `Series[i]` label lookup can.
context_col = f"context-{CONTEXT}"
prompt = prompts[context_col].tolist()
ground_truth = prompts['true-email'].tolist()
subject = prompts['subject'].tolist()
target_new = ['mail@domain.com'] * len(prompt)

In [None]:
# Sanity check: the four edit-request lists must stay aligned one-to-one.
assert len(prompt) == len(ground_truth) == len(target_new) == len(subject), "edit-request lists are misaligned"
len(prompt), len(ground_truth), len(target_new), len(subject)

In [None]:
# Set the MEMIT batch size (number of edit requests per batch).
if baseline.startswith("MEMIT"):
    BATCH_SIZE = 8 # alternative: len(prompt) to put every request in a single batch
    hparams.batch_size = BATCH_SIZE

In [None]:
# Spot-check one edit request: (prompt, true email, replacement target, subject).
INDEX = 1

prompt[INDEX], ground_truth[INDEX], target_new[INDEX], subject[INDEX]

In [None]:
# Final hyper-parameters, including the batch-size and device overrides set above.
hparams

In [None]:
# Build the editor for the configured method (presumably this loads `hparams.model_name`).
editor=BaseEditor.from_hparams(hparams)

In [None]:
%%time


# 2) perform the edit
# Arguments shared by every editing method.
edit_kwargs = dict(
    prompts=prompt,
    ground_truth=ground_truth,
    target_new=target_new,
    keep_original_weight=False,
    sequential_edit=True,
)

if baseline in ('ROME', 'R-ROME'):
    metrics, edited_model = editor.edit(subject=subject, **edit_kwargs)
elif baseline.startswith('MEMIT'):
    # startswith() matches the hparams and batch-size cells; an exact '==' check would
    # send MEMIT variants to the FT branch below without their subjects.
    metrics, edited_model, _ = editor.batch_edit(subject=subject, **edit_kwargs)
else:
    # FT does not use subjects.
    metrics, edited_model, _ = editor.batch_edit(**edit_kwargs)

# Italian for "edit finished".
print("finito l'edit")

In [None]:
# Memory usage after editing.
get_vram()

In [None]:
# Snapshot the rewritten target-layer weights of the edited model (saved in the next cell).
test = get_target_weights(edited_model, hparams)

In [None]:
# Save the edited target-layer weights. MEMIT runs also encode the batch size in the name.
if hparams.alg_name == 'MEMIT':
    save_name = f"{hparams.alg_name}_{CONTEXT}_{hparams.batch_size}_all_edited_states.pt"
else:
    save_name = f"{hparams.alg_name}_{CONTEXT}_all_edited_states.pt"

save_dir = f"edited_states_{model_name}"
# torch.save does not create missing directories; a fresh checkout would fail here otherwise.
os.makedirs(save_dir, exist_ok=True)
save_path = f"{save_dir}/{save_name}"
torch.save(test, save_path)
print(save_path)

In [None]:
# Model identifier used in the saved-weights directory name.
model_name

In [None]:
# NOTE(review): stray no-op cell (only displays 0); safe to delete.
0

In [None]:
# Stops the kernel. This presumably releases all GPU memory after a run,
# but note that it also ends "Run All" at this point.
exit()