In [None]:
!nvidia-smi

In [None]:
import os
# Index of the GPU to expose to this kernel (see the `nvidia-smi` listing above).
DEVICE_NUM = 0
# Restrict CUDA to that single device. This cell runs before `torch` is imported.
# NOTE(review): the variable presumably has no effect once CUDA is initialised — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{DEVICE_NUM}"

In [None]:
import torch
from easyeditortest import BaseEditor
from easyeditortest import ROMEHyperParams, R_ROMEHyperParams, FTHyperParams, MEMITHyperParams
from easyeditortest.editors import seed_everything


# Fix the RNG seeds so repeated runs of the notebook are comparable.
# NOTE(review): `seed_everything` presumably seeds torch as well, which would make
# the explicit `torch.manual_seed` call redundant — confirm before removing either.
torch.manual_seed(42)
seed_everything(42)


def get_vram():
    """Print current GPU memory usage as numbers and as a 24-cell text bar.

    Reads free/total device memory from ``torch.cuda.mem_get_info()`` and
    prints used/total (bytes divided by 1024**3) followed by a bar in which
    filled cells stand for used memory and hollow cells for free memory.
    Prints ``'No GPU available'`` when CUDA is unavailable.

    Returns:
        None. The function only prints.
    """
    if not torch.cuda.is_available():
        print('No GPU available')
        return

    # Query once so that `free` and `total` come from the same snapshot.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    free = free_bytes / 1024 ** 3
    total = total_bytes / 1024 ** 3
    total_cubes = 24
    free_cubes = int(total_cubes * free / total)
    print(f'VRAM: {total - free:.2f}/{total:.2f}GB\t VRAM:[' + (
            total_cubes - free_cubes) * '▮' + free_cubes * '▯' + ']')

In [None]:
get_vram()

In [None]:
# Kind of PII to edit out; one of 'url', 'twitter', 'phone' (the keys of TARGETS below).
pii_type = 'url'

# Editing method; one of "ROME", "R-ROME", "FT", "MEMIT".
# TODO: R-ROME
baseline = "MEMIT"

# Model family; 'gpt-neo' or 'gpt-j'.
model_type = 'gpt-neo'

# Sizes to choose from. NOTE(review): ['6B'] is presumably the list to use with 'gpt-j' — confirm.
model_sizes = ['1.3B', '2.7B']
model_size = model_sizes[0]

# Decoding strategy tag; in the cells shown it only selects the leaked-prompts CSV file.
alg = "greedy"


# Used to locate the hyper-parameter YAML and the output directory.
model_name = f"{model_type}-{model_size}"
model_name

In [None]:
import pandas as pd
import os
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM

# Load the hyper-parameters of the selected editing method for this model.
if baseline == "ROME":
    hparams = ROMEHyperParams.from_hparams(f'./hparams_/ROME/{model_name}.yaml')
elif baseline == "R-ROME":
    hparams = R_ROMEHyperParams.from_hparams(f'./hparams_/R-ROME/{model_name}.yaml')
elif baseline.startswith("MEMIT"):
    hparams = MEMITHyperParams.from_hparams(f'./hparams_/MEMIT/{model_name}.yaml')
elif baseline == "FT":
    BATCH_SIZE = 8
    hparams = FTHyperParams.from_hparams(f'./hparams_/FT/{model_name}.yaml')
    hparams.batch_size = BATCH_SIZE
else:
    # Fail here with a clear message instead of a NameError on `hparams` below.
    raise ValueError(f"Unknown baseline: {baseline!r}")

# CUDA_VISIBLE_DEVICES was restricted to the single GPU `DEVICE_NUM` at the top of
# the notebook, so inside this process that GPU is always index 0. Using
# `cuda:{DEVICE_NUM}` would point at a non-existent device for any DEVICE_NUM != 0.
hparams.device = "cuda:0" if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained(hparams.model_name)
# Reuse EOS as the padding token and pad on the left.
# NOTE(review): presumably needed because these tokenizers ship without a pad token — confirm.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side='left'

In [None]:
# Context-length setting of the attack prompts; it selects the leaked-prompts CSV
# and its `context-{CONTEXT}` column. Larger contexts are possible; values noted
# by the author: 50, 100, 200.
# NOTE(review): presumably measured in tokens — confirm.
CONTEXT = 200

In [None]:
#del edited_model
if torch.cuda.is_available():
    torch.cuda.empty_cache()
get_vram()

In [None]:
def get_target_weights(model, medit_hyperparams):
    """Collect the weight tensors of the layers targeted by the editing method.

    Args:
        model: Module whose ``state_dict()`` contains ``"<layer_name>.weight"``
            for every targeted layer.
        medit_hyperparams: Object exposing ``rewrite_module_tmp`` (a format
            string with one positional placeholder for the layer index) and
            ``layers`` (an iterable of layer indices).

    Returns:
        dict mapping each formatted layer name to the matching entry of the
        model's state dict (no clone is made). The shape of every tensor is
        printed as it is collected.

    Raises:
        KeyError: if a formatted layer name has no ``.weight`` entry.
    """
    target_weights = medit_hyperparams.rewrite_module_tmp
    # Build the state dict once instead of once per layer.
    state = model.state_dict()
    weights = dict()
    for layer in medit_hyperparams.layers:
        layer_name = target_weights.format(layer)
        weights[layer_name] = state[f'{layer_name}.weight']
        print(weights[layer_name].shape)
    return weights

def get_original_weights(medit_hyperparams):
    """Load the unedited pretrained model and return its target-layer weights.

    Args:
        medit_hyperparams: Hyper-parameter object providing ``model_name`` (passed
            to ``AutoModelForCausalLM.from_pretrained``) plus the attributes read
            by ``get_target_weights``.

    Returns:
        dict of layer name -> weight tensor, as produced by ``get_target_weights``.
    """
    model = AutoModelForCausalLM.from_pretrained(medit_hyperparams.model_name)
    # Only move to the GPU when one exists, matching the CUDA guards used
    # elsewhere in this notebook; an unconditional .cuda() fails on CPU-only hosts.
    if torch.cuda.is_available():
        model = model.cuda()
    weights = get_target_weights(model=model, medit_hyperparams=medit_hyperparams)
    # Drop the model reference; the returned tensors are still referenced by `weights`.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return weights

In [None]:
# Leaked-PII prompts for the selected PII type, model, context length and decoding tag.
# NOTE(review): presumably the output of the extraction attack in ../Attacks-PME — confirm.
# NOTE: the preview below may contain real PII; clear cell outputs before sharing.
prompts = pd.read_csv(f"../Attacks-PME/leaked-{pii_type}/{model_type}-{model_size}-{CONTEXT}-{alg}.csv")
print(len(prompts))
prompts.head()

In [None]:
from datasets import Dataset
import pandas as pd

def load_data(filename):
    """Read a dataset previously saved to disk at ``filename`` and return it."""
    dataset = Dataset.load_from_disk(filename)
    return dataset


# Row cap applied to the 'url' dataset only.
URL_ROW_CAP = 4550

pii_frame = pd.DataFrame(load_data(f"../Attacks-PME/Pile-CC-tomekkorbak-{pii_type}"))
pii_frame['context'] = pii_frame['context'].apply(str.strip)

# Deterministic down-sampling (fixed random_state) when the url set exceeds the cap.
if pii_type == 'url' and len(pii_frame) > URL_ROW_CAP:
    pii_frame = pii_frame.sample(n=URL_ROW_CAP, random_state=42).reset_index(drop=True)

display(pii_frame.head())

# Keep only these four columns and convert back to a `datasets.Dataset`.
data = Dataset.from_pandas(pii_frame[['pii','pii_type','context','subject']])

data

In [None]:
# Map each PII string to its subject; if a PII value occurs in several rows, the last row wins.
subjects = {e['pii']: e['subject'] for e in data}
len(subjects)

In [None]:
# Per PII type, the placeholder string used as the edit's `target_new`
# (i.e. what the model should produce instead of the true PII).
TARGETS = {
    'phone' : 'phone_number',  # alternative noted by the author: "000-000-0000"
    'twitter': 'twitter_id',
    'url': 'address_web'
}

In [None]:
%%time

# Parallel lists: entry k of each list describes the k-th edit request.
prompt, ground_truth, target_new, subject = [], [], [], []

# Iterate positionally over the two columns so the loop does not depend on the
# DataFrame's index labels.
for true_pii, training_example in zip(prompts[f'true-{pii_type}'], prompts[f"context-{CONTEXT}"]):
    # Curly braces are stripped from the context, one message per kind of brace found.
    # NOTE(review): presumably because the editor applies str.format() to prompts — confirm.
    for brace in ('{', '}'):
        if brace in training_example:
            training_example = training_example.replace(brace, '')
            print('invalid character')

    prompt.append(training_example)
    ground_truth.append(true_pii)
    target_new.append(TARGETS[pii_type])

    # `subjects` only covers the rows kept in `data` (which may have been
    # sub-sampled), so a leaked PII can be missing from it. Treat that case like
    # a subject that does not occur in the context instead of raising KeyError.
    known_subject = subjects.get(true_pii)
    if known_subject is not None and known_subject in training_example:
        subject.append(known_subject)
    else:
        print("Subj not found, using last word")
        subject.append(training_example.split()[-1])

In [None]:
len(prompt), len(ground_truth), len(target_new), len(subject)

In [None]:
# Batch size for MEMIT runs; the value is also embedded in the name of the saved weights file.
if baseline.startswith("MEMIT"):
    BATCH_SIZE = 8  # alternative noted by the author: len(prompt)
    hparams.batch_size = BATCH_SIZE

In [None]:
INDEX = 1

prompt[INDEX], ground_truth[INDEX], target_new[INDEX], subject[INDEX]

In [None]:
hparams

In [None]:
editor=BaseEditor.from_hparams(hparams)

In [None]:
%%time


# 2) perform the edit
# ROME-style baselines go through `edit`; every other baseline goes through `batch_edit`.
if baseline == 'ROME' or baseline == 'R-ROME':
    metrics, edited_model = editor.edit(
        prompts=prompt,
        ground_truth=ground_truth,
        target_new=target_new,
        subject=subject,
        keep_original_weight=False,
        sequential_edit=True
    )
elif baseline.startswith('MEMIT'):
    # `startswith` matches the checks used when loading the hparams and setting the
    # batch size, so a MEMIT variant name cannot silently fall through to the
    # subject-less branch below.
    metrics, edited_model, _ = editor.batch_edit(
        prompts=prompt,
        ground_truth=ground_truth,
        target_new=target_new,
        subject=subject,
        keep_original_weight=False,
        sequential_edit=True
    )
else:
    # Remaining baselines (e.g. FT) are called without `subject`.
    metrics, edited_model, _ = editor.batch_edit(
        prompts=prompt,
        ground_truth=ground_truth,
        target_new=target_new,
        keep_original_weight=False,
        sequential_edit=True
    )

print("finito l'edit")

In [None]:
get_vram()

In [None]:
test = get_target_weights(edited_model, hparams)

In [None]:
# Persist the edited target-layer weights. MEMIT runs also record the batch size
# in the file name.
out_dir = f"edited_states_{model_name}"
if hparams.alg_name == 'MEMIT':
    out_name = f"{hparams.alg_name}_{CONTEXT}_{hparams.batch_size}_{pii_type}_all_edited_states.pt"
else:
    out_name = f"{hparams.alg_name}_{CONTEXT}_{pii_type}_all_edited_states.pt"
out_path = f"{out_dir}/{out_name}"

# torch.save does not create missing directories; create the output directory
# first so the result of a long edit run is not lost to a path error.
os.makedirs(out_dir, exist_ok=True)
torch.save(test, out_path)
print(out_path)

In [None]:
0

In [None]:
exit()