-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path2.py
More file actions
50 lines (39 loc) · 1.33 KB
/
2.py
File metadata and controls
50 lines (39 loc) · 1.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gc
import os
import torch
from utils import (
ITERATIONS,
MODELS,
get_hf_models,
get_memory_usage,
plot_memory_usage,
save_csv,
)
def _record(metrics, label):
    """Append the result of ``get_memory_usage(label)`` to *metrics*."""
    metrics.append(get_memory_usage(label))


def _load_models(models, metrics, prefix):
    """Load every model listed in ``MODELS["hf"]``, move it to ``cuda:0`` and keep it.

    Each model is obtained via ``get_hf_models`` and appended to *models*.
    A snapshot labelled with *prefix* is recorded after loading each model
    and again after moving it to ``cuda:0``. The local ``model`` reference
    goes away when this helper returns, so *models* holds the only
    references afterwards.
    """
    for i, model_name in enumerate(MODELS["hf"]):
        model = get_hf_models(model_name)
        _record(metrics, prefix + f"after get models[{i}]")
        model = model.to("cuda:0")
        _record(metrics, prefix + f"to cuda models[{i}]")
        models.append(model)


def _release_models(models, metrics, prefix):
    """Delete the entries of *models* one at a time, front first.

    After each deletion, ``gc.collect()`` and ``torch.cuda.empty_cache()``
    are called before a snapshot labelled with *prefix* is recorded.
    *models* is empty on return.
    """
    num_models = len(models)
    for i in range(num_models):
        del models[0]
        gc.collect()
        torch.cuda.empty_cache()
        _record(metrics, prefix + f"after del models[{i}]")


def main():
    """Record memory snapshots while repeatedly loading and deleting models.

    Runs ``ITERATIONS`` rounds. Each round loads every model in
    ``MODELS["hf"]`` onto ``cuda:0`` and then deletes them one by one,
    recording a snapshot around every step. The snapshots are plotted with
    ``keys=["vram"]`` to ``figs/memory_usage_2.png`` and passed to
    ``save_csv`` for ``csv/memory_usage_2.csv``.
    """
    fig_path = "figs/memory_usage_2.png"
    csv_path = "csv/memory_usage_2.csv"
    # Ensure the output directories exist before the (potentially long)
    # benchmark, so a missing directory cannot discard the collected results.
    for path in (fig_path, csv_path):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    models = []
    metrics = []
    _record(metrics, "initial")
    for iter_idx in range(ITERATIONS):
        prefix = f"[{iter_idx:2d}] "
        _record(metrics, prefix + "before get models")
        _load_models(models, metrics, prefix)
        _record(metrics, prefix + "after get models")
        _release_models(models, metrics, prefix)
        _record(metrics, prefix + "after del all models")
    _record(metrics, "final")

    plot_memory_usage(metrics, fig_path, keys=["vram"])
    # Presumably loosened so other users can access the figure; it is a data
    # file, so read/write for everyone is enough (no execute bits).
    os.chmod(fig_path, 0o666)
    save_csv(metrics, csv_path)


if __name__ == "__main__":
    main()