In [11]:
%%writefile extract_gradients.py
import argparse
import os

parser = argparse.ArgumentParser("gradient_extraction")
parser.add_argument("dataset_folder", help="Path to a dataset folder of .train files that can be read by calling load_dataset('text', <path>)")
parser.add_argument("model_dir", help="Where the model and checkpoints are stored")
parser.add_argument("gradient_output_dir", help="Where to save the gradients to")


parser.add_argument("--num_processes", help="Number of processes to use (one model per process)", type=int, nargs="?", const=1, default=6)
parser.add_argument("--cuda_visible_devices", help="Comma seperated GPU ids to use", nargs="?", const=1, default="0,1")
parser.add_argument("--gradients_per_file", help="Number of gradients per output file", type=int, nargs="?", const=1, default=10000)
args = parser.parse_args()


if not os.path.exists(args.gradient_output_dir):
    os.makedirs(args.gradient_output_dir)

os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices

os.environ["TOKENIZERS_PARALLELISM"] = "False"


from transformers import RobertaConfig,AutoConfig
from transformers import RobertaForMaskedLM
import torch

from transformers import RobertaTokenizerFast
from transformers import DataCollatorForLanguageModeling
tokenizer = RobertaTokenizerFast.from_pretrained(args.model_dir, max_len=512)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)


from tqdm import tqdm
def get_loss_gradient(model, example,device):
    model.train()
    model.zero_grad()
    
    input_ids, labels = data_collator((torch.tensor(example),)).values() # TODO make static maybe probably
    inputs_embeds=model.get_input_embeddings().weight[input_ids].to(device)
    inputs_embeds.retain_grad()

    outputs = model.forward(
            inputs_embeds=inputs_embeds,
            labels=labels.to(device)
        )
    loss = outputs.loss
    loss.retain_grad()
    return  torch.autograd.grad(loss, inputs_embeds, retain_graph=True)[0].squeeze()

def get_for_checkpoint(checkpoint_path, i_start, i_end):
    
    try:
        gpu_id = queue.get()
        out_path = os.path.join(args.gradient_output_dir, checkpoint_path.split("-")[-1] + "_" +  str(i_start) + "_" + str(i_end))
        
        if os.path.isfile(out_path):
            queue.put(gpu_id)
            return out_path

        from datasets import load_dataset
        dataset = load_dataset("text", data_dir=args.dataset_folder)
        dataset.set_transform(lambda x : tokenizer(x["text"], return_special_tokens_mask=True, truncation=True, padding="max_length", max_length=512))

        device = "cuda:" + str(gpu_id)
        config = AutoConfig.from_pretrained(checkpoint_path)
        model = RobertaForMaskedLM(config=config).to(device)

        gradients = [get_loss_gradient(model, example,device).to(torch.bfloat16) for example in tqdm(dataset["train"][i_start:i_end]["input_ids"])]

        torch.save( torch.stack(gradients), out_path)
        queue.put(gpu_id)
        return out_path
    except Exception as e:
        print(e,flush=True)

from multiprocessing import Pool, current_process, Queue
import time 
import datetime
import os
from pathlib import Path
import torch
from itertools import cycle




from transformers import RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained(args.model_dir, max_len=512)


from util import get_epoch_checkpoints
checkpoints = [str(x) for x in Path(args.model_dir).glob("checkpoint-*") if int(str(x).split("-")[-1]) in get_epoch_checkpoints(args.model_dir)]
print("checkpoints", checkpoints)
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
from datasets import load_dataset
dataset = load_dataset("text", data_dir=args.dataset_folder)
dataset.set_transform(lambda x : tokenizer(x["text"], return_special_tokens_mask=True, truncation=True, padding="max_length", max_length=512))


import itertools
queue = Queue()

for _ in range(args.num_processes//torch.cuda.device_count()):
    for i in range(torch.cuda.device_count()):
        queue.put(i)

from multiprocessing import Pool


if __name__ == '__main__':
    
    tasks = list(itertools.chain(*[[(checkpoint,i, i + args.gradients_per_file) for i in range(0, len(dataset["train"]), args.gradients_per_file)]  for checkpoint in checkpoints ]))
    with Pool(args.num_processes) as p:
        r = p.starmap_async(get_for_checkpoint, tasks, callback=print )
        r.wait()

Overwriting extract_gradients.py


### Time Estimate
For 10M 10 checkpoints (one per epoch)

In [11]:
args = {
    "dataset_folder": "./train_10M",
    "model_dir": "./10MModel",
    "num_processes": 4,
    "influence_output_dir": "./influence",
    "curriculum_output_folder": "./10MCurriculum",
    "gradient_input_dir": "./gradients",
    "gradients_per_file" : 10000,
    "epochs" : 5,
    "batch_size": 10
}

In [15]:
(args["batch_size"]+1)*args["num_processes"]*7.4

325.6

In [7]:
# %run extract_gradients.py ./train_10M ./10MModel ./gradients 

In [8]:
TIME_PER_BATCH = 4

In [13]:
from pathlib import Path
import os
import torch
from datasets import load_dataset
dataset = load_dataset("text", data_dir=args["dataset_folder"])
#dataset.set_transform(lambda x : tokenizer(x["text"], return_special_tokens_mask=True, truncation=True, padding="max_length", max_length=512))

print(((((len(dataset["train"]) / args["gradients_per_file"]) *args["epochs"]) /args["num_processes"]) * TIME_PER_BATCH)/60, "hours")

9.825116666666666 hours


In [16]:
(((((len(dataset["train"]) / args["gradients_per_file"]))**2))/args["batch_size"]*args["epochs"]*10)/args["num_processes"]/60/60

4.826645875680556

In [23]:
(((((len(dataset["train"]) / args["gradients_per_file"]))**2)/10)*6.9)/60/24

6.660771308439167

In [38]:
(((((len(dataset["train"]) / args["gradients_per_file"]))**2))*args["epochs"]*12)/args["batch_size"]/args["num_processes"]/60

26.73219254223077

In [64]:
print((((len(dataset["train"]) / args["gradients_per_file"]) *args["epochs"]) *7.4)/1000, "TB")

4.3623518 TB


In [57]:
(len(dataset["train"]) / args["gradients_per_file"])*7.4

872.47036

In [9]:
[[ os.path.join(args["gradient_input_dir"], a + "_" + str(i) + "_" + str(i +args["gradients_per_file"])) for i in range(0, len(dataset["train"]), args["gradients_per_file"])] for a in epoch_checkpoints]



['./gradients/589488_0_10000',
 './gradients/589488_10000_20000',
 './gradients/589488_20000_30000',
 './gradients/589488_30000_40000',
 './gradients/589488_40000_50000',
 './gradients/589488_50000_60000',
 './gradients/589488_60000_70000',
 './gradients/589488_70000_80000',
 './gradients/589488_80000_90000',
 './gradients/589488_90000_100000',
 './gradients/589488_100000_110000',
 './gradients/589488_110000_120000',
 './gradients/589488_120000_130000',
 './gradients/589488_130000_140000',
 './gradients/589488_140000_150000',
 './gradients/589488_150000_160000',
 './gradients/589488_160000_170000',
 './gradients/589488_170000_180000',
 './gradients/589488_180000_190000',
 './gradients/589488_190000_200000',
 './gradients/589488_200000_210000',
 './gradients/589488_210000_220000',
 './gradients/589488_220000_230000',
 './gradients/589488_230000_240000',
 './gradients/589488_240000_250000',
 './gradients/589488_250000_260000',
 './gradients/589488_260000_270000',
 './gradients/589488_270

## Some visualizations

In [14]:
torch.cuda.device_count()

4

In [105]:
800/(2*7.4)

54.05405405405405

In [17]:
(len(dataset["train"]) / (16))  

73688.375

In [103]:
(len(dataset["train"]) / args["gradients_per_file"])+1

117.9014

In [7]:
sum([1 for x in Path(args["gradient_input_dir"]).glob("*")])

18

In [10]:
((((len(dataset["train"]) / args["gradients_per_file"])**2)/args["num_processes"])*5*29)/60

646.0279864372436

In [115]:
sorted([int(str(x).split("-")[-1]) for x in Path(args["model_dir"]).glob("checkpoint-*")])

[24562,
 49124,
 73686,
 98248,
 122810,
 147372,
 171934,
 196496,
 221058,
 245620,
 270182,
 294744,
 319306,
 343868,
 368430,
 392992,
 417554,
 442116,
 466678,
 491240,
 515802,
 540364,
 564926,
 589488,
 614050,
 638612,
 663174,
 687736,
 712298,
 736860,
 736890]

In [27]:
import util
epoch_checkpoints = util.get_epoch_checkpoints(args["model_dir"])
epoch_checkpoints

[147372, 294744, 442116, 589488, 736890]

In [31]:
 list(itertools.chain(*[[ os.path.join(args["gradient_input_dir"], str(a) + "_" + str(i) + "_" + str(i +args["gradients_per_file"])) for i in range(0, len(dataset["train"]), args["gradients_per_file"])] for a in epoch_checkpoints]))



['./gradients/147372_0_10000',
 './gradients/147372_10000_20000',
 './gradients/147372_20000_30000',
 './gradients/147372_30000_40000',
 './gradients/147372_40000_50000',
 './gradients/147372_50000_60000',
 './gradients/147372_60000_70000',
 './gradients/147372_70000_80000',
 './gradients/147372_80000_90000',
 './gradients/147372_90000_100000',
 './gradients/147372_100000_110000',
 './gradients/147372_110000_120000',
 './gradients/147372_120000_130000',
 './gradients/147372_130000_140000',
 './gradients/147372_140000_150000',
 './gradients/147372_150000_160000',
 './gradients/147372_160000_170000',
 './gradients/147372_170000_180000',
 './gradients/147372_180000_190000',
 './gradients/147372_190000_200000',
 './gradients/147372_200000_210000',
 './gradients/147372_210000_220000',
 './gradients/147372_220000_230000',
 './gradients/147372_230000_240000',
 './gradients/147372_240000_250000',
 './gradients/147372_250000_260000',
 './gradients/147372_260000_270000',
 './gradients/147372_270

In [42]:
# for file_name in os.listdir("./gradients"):
#     file_path = os.path.join("./gradients", file_name)
    
#     # Check if the current file is not in the list of files to keep
#     if str(file_path) not in epoch_checkpoints:

        
#         if os.path.isfile(file_path):  # Ensure it's a file and not a directory
#             # print(file_path)
#             # continue
#             os.remove(file_path)  # Delete the file
#             print(f"Deleted: {file_path}")
#         else:
#             print(f"Skipped (not a file): {file_path}")
#     else:
#         print(f"Kept: {file_path}")

Deleted: ./gradients/638612_30000_40000
Deleted: ./gradients/368430_780000_790000
Deleted: ./gradients/589488_1060000_1070000
Deleted: ./gradients/589488_990000_1000000
Deleted: ./gradients/221058_240000_250000
Deleted: ./gradients/368430_870000_880000
Deleted: ./gradients/589488_210000_220000
Deleted: ./gradients/294744_240000_250000
Deleted: ./gradients/294744_890000_900000
Deleted: ./gradients/294744_790000_800000
Deleted: ./gradients/589488_100000_110000
Deleted: ./gradients/614050_1120000_1130000
Deleted: ./gradients/589488_540000_550000
Deleted: ./gradients/221058_300000_310000
Deleted: ./gradients/294744_130000_140000
Deleted: ./gradients/368430_90000_100000
Deleted: ./gradients/221058_700000_710000
Deleted: ./gradients/221058_820000_830000
Deleted: ./gradients/73686_320000_330000
Deleted: ./gradients/221058_1080000_1090000
Deleted: ./gradients/687736_20000_30000
Deleted: ./gradients/221058_720000_730000
Deleted: ./gradients/589488_1050000_1060000
Deleted: ./gradients/294744_180

In [28]:
import itertools
sum([1 for path 
 in list(itertools.chain(*[[ os.path.join(args["gradient_input_dir"], str(a) + "_" + str(i) + "_" + str(i +args["gradients_per_file"])) for i in range(0, len(dataset["train"]), args["gradients_per_file"])] for a in epoch_checkpoints]))
    if os.path.isfile(path)
])/sum([1 for path 
 in list(itertools.chain(*[[ os.path.join(args["gradient_input_dir"], str(a) + "_" + str(i) + "_" + str(i +args["gradients_per_file"])) for i in range(0, len(dataset["train"]), args["gradients_per_file"])] for a in epoch_checkpoints]))
   
])


0.40847457627118644

In [7]:
((((len(dataset["train"]) / 10000) *31) ) * 7.4)/1000

27.04658116

In [40]:
IDX = 9999
dataset["train"][IDX]

{'text': "It's very good isn't it?"}

In [128]:
len([c for c in os.listdir("10MModel") if "checkpoint" in c])

31

In [68]:
import pandas as  pd

In [160]:
test_data = torch.load(os.path.join(args["gradient_input_dir"],"638612_0_10000"), weights_only=True,map_location="cpu")

In [156]:
{int(p.split("_")[0]) : torch.load(os.path.join(args["gradient_input_dir"],p), weights_only=True,map_location="cpu")[IDX] for p in os.listdir(args["gradient_input_dir"]) if (IDX >= int(p.split("_")[1])) and (IDX < int(p.split("_")[2]))}

{638612: tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.bfloat16),
 294744: tensor([[-1.1861e-05,  1.0147e-03, -3.9816e-05,  ..., -1.8024e-04,
           3.8338e-04,  3.2234e-04],
         [-2.6953e-01, -4.5703e-01,  3.4375e-01,  ..., -1.1328e+00,
          -8.3923e-04,  4.1211e-01],
         [ 2.1362e-03,  4.7493e-04,  1.3123e-03,  ...,  1.4484e-05,
           5.0735e-04, -2.1267e-04],
         ...,
         [ 1.5945e-03,  4.6539e-04,  8.3160e-04,  ...,  2.1362e-04,
           6.1798e-04, -2.8419e-04],
         [ 1.3885e-03,  6.2180e-04,  1.0910e-03,  ..., -4.4632e-04,
           3.3569e-04,  3.9814e-08],
         [ 1.5411e-03,  5.0783e-05,  9.6560e-06,  ..., -1.5068e-04,
           7.3624e-04, -2.8801e-04]], dtype=torch.bfloat16)}

In [179]:
gradients =[ (int(p.split("_")[0]), torch.load(os.path.join(args["gradient_input_dir"],p), weights_only=True,map_location="cpu")[IDX].float().flatten(0)) for p in os.listdir(args["gradient_input_dir"]) if (IDX >= int(p.split("_")[1])) and (IDX < int(p.split("_")[2]))]

In [180]:
gradients = sorted(gradients, key=lambda x: x[0])
_, gradients = zip(*gradients)

In [182]:
list(gradients)

[tensor([-1.1861e-05,  1.0147e-03, -3.9816e-05,  ..., -1.5068e-04,
          7.3624e-04, -2.8801e-04]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.])]

In [183]:
gradients

(tensor([-1.1861e-05,  1.0147e-03, -3.9816e-05,  ..., -1.5068e-04,
          7.3624e-04, -2.8801e-04]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]))

In [184]:
torch.stack(gradients).shape

torch.Size([2, 393216])

In [3]:
# import pickle 
# import os


# path = "/data/loriss21dm/babylm/results/10000"

# gradients_at_checkpoint = torch.load(path,weights_only=True,map_location="cpu")
# len(gradients_at_checkpoint)


FileNotFoundError: [Errno 2] No such file or directory: '/data/loriss21dm/babylm/results/10000'

In [185]:
test_data.flatten(1).shape

torch.Size([10000, 393216])

In [188]:
gradients.shape

AttributeError: 'tuple' object has no attribute 'shape'

In [195]:
(torch.stack(gradients) * test_data.flatten(1)).shape

RuntimeError: The size of tensor a (2) must match the size of tensor b (10000) at non-singleton dimension 0

In [203]:
import numpy as np
np.dot(torch.stack(gradients).float().numpy(), test_data.flatten(1).T.float().numpy()).mean(1)

array([-0.01829289,  0.        ], dtype=float32)

In [191]:
torch.stack(gradients).shape

torch.Size([2, 393216])

In [194]:
test_data.flatten(1).T.shape

torch.Size([393216, 10000])

In [190]:
torch.bmm(torch.stack(gradients),test_data)

RuntimeError: batch1 must be a 3D tensor

In [164]:
torch.stack(gradients).shape

torch.Size([2, 1, 393216])

In [154]:
t_flat = torch.stack(gradients)#.flatten(-2,-1)
t_flat.shape

torch.Size([2, 1, 393216])

In [155]:
torch.bmm(t_flat, torch.transpose(t_flat, 1,2)).mean(1)

tensor([[176.0469],
        [  0.0000]])

In [40]:
import numpy as np

In [None]:
puhzf9dupkoülpä

In [41]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
tensor = torch.bmm(t_flat, torch.transpose(t_flat, 1,2)).mean(1)[:, 0:10000]

x = np.arange(tensor.shape[0])  
y = np.arange(tensor.shape[1])


fig = plt.figure(figsize=(10, 15))
ax = fig.add_subplot(111, projection='3d')

colors = cm.viridis(np.linspace(0, 1, tensor.shape[1]))
for i in range(tensor.shape[1]):  
    ax.plot(x, np.full_like(x, y[i]), tensor[:, i], color=colors[i], linewidth=2)


ax.set_xlabel('Time [Checkpoints]')
ax.set_ylabel('Training example #')
ax.set_zlabel('Influence at checkpoint')


ax.view_init(elev=25, azim=-40, roll=0)

In [42]:

for i, inf in enumerate(torch.bmm(t_flat, torch.transpose(t_flat, 1,2))):
    plt.imshow(inf)
    plt.show()

In [43]:
np.einsum('ijkl,nolp->ijnokp', t_flat, t_flat.T)

In [24]:
# # test_instance =  "Test <mask>"
# # training_examples = ["test", "is"]


# start = time.time()

# models = []
# influences_at_cps = []
# for checkpoint in tqdm(checkpoints, desc="Checkpoints"):
#   config = RobertaConfig.from_pretrained(checkpoint)
#   model = RobertaForMaskedLM(config=config)
#   influences_at_cps.append(pw_influence_at_cp(model, test_instance,training_examples["input_ids"]))
# end = time.time()
# print(datetime.timedelta(seconds=end - start))    
# influences_total = torch.stack(influences_at_cps).sum(dim=0)
# influences_total

In [25]:
torch.stack(influences_at_cps).T.shape

In [83]:
influences_total