In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gc

gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

models = ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"]

for m in models:
    model = AutoModelForCausalLM.from_pretrained(m, torch_dtype=torch.float16).cuda()

    # the fast tokenizer currently does not work correctly
    tokenizer = AutoTokenizer.from_pretrained(m, use_fast=False)

    prompt = "Hello, I'm am conscious and"

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

    generated_ids = model.generate(input_ids)

    tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    print("Running model: ", m) 
    print("Memory used: ", torch.cuda.max_memory_allocated() / (1e9))
    
    del model, tokenizer
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    print("Cleaned memory down to: ", torch.cuda.max_memory_allocated() / (1e9))



Running model:  facebook/opt-125m
Memory used:  0.258976256
Cleaned memory down to:  1.024e-06


Downloading:   0%|          | 0.00/644 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/663M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/685 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/441 [00:00<?, ?B/s]

Running model:  facebook/opt-350m
Memory used:  0.667186688
Cleaned memory down to:  1.024e-06


Downloading:   0%|          | 0.00/653 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.63G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/685 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/441 [00:00<?, ?B/s]

Running model:  facebook/opt-1.3b
Memory used:  2.639000064
Cleaned memory down to:  1.024e-06


Downloading:   0%|          | 0.00/691 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/5.30G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/685 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/441 [00:00<?, ?B/s]

Running model:  facebook/opt-2.7b
Memory used:  5.316085248
Cleaned memory down to:  1.024e-06
Running model:  facebook/opt-6.7b
Memory used:  13.336560128
Cleaned memory down to:  1.024e-06


In [2]:
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

models = ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"]

for m in models:
    model = AutoModelForCausalLM.from_pretrained(m, torch_dtype=torch.float16).cuda()

    # the fast tokenizer currently does not work correctly
    tokenizer = AutoTokenizer.from_pretrained(m, use_fast=False)

    prompt = "The following document contains a compilation of 3000 tweets from various authors:"

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

    generated_ids = model.generate(input_ids, max_new_tokens=2000, min_length=2000)

    tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    print("Running model: ", m) 
    print("Memory used: ", torch.cuda.max_memory_allocated() / (1e9))
    
    del model, tokenizer
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    print("Cleaned memory down to: ", torch.cuda.max_memory_allocated() / (1e9))

Running model:  facebook/opt-125m
Memory used:  0.406125568
Cleaned memory down to:  1.6896e-05
Running model:  facebook/opt-350m
Memory used:  1.066847744
Cleaned memory down to:  1.6896e-05
Running model:  facebook/opt-1.3b
Memory used:  3.423556096
Cleaned memory down to:  1.6896e-05
Running model:  facebook/opt-2.7b
Memory used:  6.646636032
Cleaned memory down to:  1.6896e-05
Running model:  facebook/opt-6.7b
Memory used:  15.446245888
Cleaned memory down to:  1.6896e-05
