In [None]:
import torch
import os
import time
import json
## Load Config
# Project configuration: `videos` and `name_to_url` are loaded from JSON files under config/.
# Files are opened with an explicit UTF-8 encoding (the standard encoding for JSON) so that
# loading does not depend on the platform's default locale encoding.
with open('config/videos.json', encoding='utf-8') as config_file:
    videos = json.load(config_file)
with open('config/name_to_url.json', encoding='utf-8') as config_file:
    name_to_url = json.load(config_file)


In [None]:
# Load model directly
# Load the Mistral-7B-Instruct tokenizer and float16 weights, and move the model to `device`.
import importlib.util

from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.1"

torch.cuda.empty_cache()
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# FlashAttention-2 needs a CUDA device and the `flash_attn` package. Requesting it
# unconditionally makes the CPU fallback above fail at load time, so only request it
# when both requirements are met and otherwise use eager attention.
if device.startswith("cuda") and importlib.util.find_spec("flash_attn") is not None:
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = "eager"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    attn_implementation=attn_implementation,
).to(device)

# Alternative model (currently unused):
# tokenizer = AutoTokenizer.from_pretrained("TabbyML/StarCoder-7B")
# model = AutoModelForCausalLM.from_pretrained("TabbyML/StarCoder-7B")

In [None]:
## Single inference
# Build chat-formatted input ids for a single target transcript. The conversation is a
# one-shot example (prompt + example transcript -> example answer) followed by
# prompt + target transcript. This cell only prepares `inputs`; it does not call generate().
oneshot = "hashing"
target = "mlops_llm_eval"
MAX_TRANSCRIPT_CHARS = 5000  # each transcript is truncated to its first 5000 characters

# Explicit UTF-8 so reading does not depend on the platform's default encoding.
with open('data/prompts/prompt.txt', 'r', encoding='utf-8') as f_prompt:
    prompt = f_prompt.read()
with open(f'data/oneshots/{oneshot}.txt', 'r', encoding='utf-8') as f_oneshot:
    oneshot_answer = f_oneshot.read()
with open(f'data/transcripts/processed/{oneshot}.txt', 'r', encoding='utf-8') as f_oneshot_transcript:
    oneshot_transcript = f_oneshot_transcript.read()[:MAX_TRANSCRIPT_CHARS]
with open(f'data/transcripts/processed/{target}.txt', 'r', encoding='utf-8') as f_target_transcript:
    target_transcript = f_target_transcript.read()[:MAX_TRANSCRIPT_CHARS]

messages = [
    {"role": "user", "content": f'{prompt}{oneshot_transcript}'},
    {"role": "assistant", "content": oneshot_answer},
    {"role": "user", "content": f'{prompt}{target_transcript}'},
]
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)


In [None]:
## All Inference
# For every transcript in data/transcripts/processed, build a one-shot conversation
# (prompt + example transcript -> example answer, then prompt + target transcript),
# sample a completion, and write it to data/outputs/{oneshot}/text/{target}.txt.
oneshot = "hashing"
MAX_TRANSCRIPT_CHARS = 5000  # each transcript is truncated to its first 5000 characters
MAX_NEW_TOKENS = 1500

transcript_dir = 'data/transcripts/processed'
output_dir = f'data/outputs/{oneshot}/text'
os.makedirs(output_dir, exist_ok=True)

torch.cuda.empty_cache()

# The prompt and the one-shot example are identical for every target, so read them once.
with open('data/prompts/prompt.txt', 'r', encoding='utf-8') as f_prompt:
    prompt = f_prompt.read()
with open(f'data/oneshots/{oneshot}.txt', 'r', encoding='utf-8') as f_oneshot:
    oneshot_answer = f_oneshot.read()
with open(os.path.join(transcript_dir, f'{oneshot}.txt'), 'r', encoding='utf-8') as f_oneshot_transcript:
    oneshot_transcript = f_oneshot_transcript.read()[:MAX_TRANSCRIPT_CHARS]

# NOTE: the one-shot example's own transcript lives in transcript_dir too, so it is
# also processed as a target.
for file in sorted(os.listdir(transcript_dir)):
    # splitext keeps dotted names intact (split('.')[0] would turn "a.b.txt" into "a")
    # and lets stray non-transcript files (e.g. .DS_Store) be skipped instead of crashing.
    target, ext = os.path.splitext(file)
    if ext != '.txt':
        continue
    with open(os.path.join(transcript_dir, file), 'r', encoding='utf-8') as f_target_transcript:
        target_transcript = f_target_transcript.read()[:MAX_TRANSCRIPT_CHARS]

    messages = [
        {"role": "user", "content": f'{prompt}{oneshot_transcript}'},
        {"role": "assistant", "content": oneshot_answer},
        {"role": "user", "content": f'{prompt}{target_transcript}'},
    ]
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
    prompt_length = inputs.shape[-1]  # generate() returns these prompt tokens before the new ones

    with torch.no_grad():
        start_time = time.time()
        # do_sample=True: completions differ between runs unless a torch seed is set.
        text = model.generate(inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=True)
        print(f'Decoding finished: {target} in {round(time.time() - start_time, 3)} seconds')

    # Decode only the newly generated tokens and drop special tokens such as </s>.
    # Splitting the fully decoded conversation on "[/INST]" and slicing off a trailing
    # "</s>" would cut real text when generation stops at MAX_NEW_TOKENS without </s>,
    # and would pick the wrong segment if any input contained "[/INST]".
    completion = tokenizer.decode(text[0][prompt_length:], skip_special_tokens=True)
    with open(os.path.join(output_dir, f'{target}.txt'), 'w', encoding='utf-8') as f:
        f.write(completion.strip())

    del text, inputs
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        # memory_allocated()/memory_reserved() return plain int byte counts (no .item());
        # memory_reserved() replaces the deprecated memory_cached().
        print(f'cuda memory allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB', f'cuda memory cached: {torch.cuda.memory_reserved()/1024**3:.2f} GB', f'total cuda memory: {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB')



In [None]:
# Create embeddings from outputs
# For each generated text in data/outputs/{oneshot}/text, run one forward pass and pool the
# last hidden layer over the token axis into one vector, using both max and mean pooling.
# The results are saved to data/outputs/{oneshot}/embeddings/{max,avg}_{target}.pt.
# Relies on `oneshot`, `tokenizer`, `model`, and `device` defined in earlier cells.
text_dir = f'data/outputs/{oneshot}/text'
embedding_dir = f'data/outputs/{oneshot}/embeddings'
os.makedirs(embedding_dir, exist_ok=True)

for file in sorted(os.listdir(text_dir)):
    # splitext keeps dotted names intact and lets non-.txt files be skipped.
    target, ext = os.path.splitext(file)
    if ext != '.txt':
        continue
    # memory_reserved() replaces the deprecated memory_cached() alias.
    print(f'1cuda memory allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB', f'cuda memory cached: {torch.cuda.memory_reserved()/1024**3:.2f} GB')

    start_time = time.time()
    with open(os.path.join(text_dir, file), 'r', encoding='utf-8') as f:
        inputs = tokenizer(f.read(), return_tensors="pt").to(device)
    with torch.no_grad():
        hidden = model(**inputs, output_hidden_states=True)
    print(f'Hidden states finished: {target} in {round(time.time() - start_time, 3)} seconds')

    ## Write hidden states
    # The input here is only the saved completion text (no prompt or one-shot example),
    # so every token is pooled. hidden_states[-1] is [batch, token, feature]; take the
    # single batch entry and reduce over tokens -> [feature].
    last_hidden = hidden.hidden_states[-1][0]
    # Move to CPU before saving so the .pt files load without a GPU or map_location.
    max_pool = last_hidden.max(dim=0).values.cpu()
    avg_pool = last_hidden.mean(dim=0).cpu()
    print(max_pool.shape)
    torch.save(max_pool, os.path.join(embedding_dir, f'max_{target}.pt'))
    torch.save(avg_pool, os.path.join(embedding_dir, f'avg_{target}.pt'))

    print(f'Target: {target} finished. Wrote to file.')

    del inputs, hidden, last_hidden, max_pool, avg_pool
    print(f'cuda memory allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB', f'cuda memory cached: {torch.cuda.memory_reserved()/1024**3:.2f} GB')
    torch.cuda.empty_cache()





In [None]:
### Benchmark notes (T/S = tokens per second, i.e. Tokens / Runtime in seconds)

### catan.txt
## Clean transcript, with new lines
## Char: 9388, Word: 1840, Tokens: 2483, Runtime: 3m 4.2s on PyTorch MPS, T/S = 13.5

## Clean transcript, no new lines
## Char: 9388, Word: 1840, Tokens: 2174, Runtime: 2m 43.9s on PyTorch MPS, T/S = 13.25

## Clean transcript, no new lines, while charging
## Char: 9388, Word: 1840, Tokens: 2174, Runtime: 2m 32.8s on PyTorch MPS, T/S = 14.23

#### On A10s
# 8,000-token limit for Mistral-7B

### catan.txt
## Tokens: 2174, Runtime: 48.8s, T/S = 44.55

### mixtral8x7b.txt
## Tokens: 16981, Runtime: 13m 8.4s, T/S = 21.59

### mlops_llm_eval.txt
## Tokens: 10993, Runtime: 7m 5.1s, T/S = 25.87

### typescript_fireship.txt
## Tokens: 1042, Runtime: 27.6s, T/S = 37.75

### localized_deployment.txt
## Tokens: 892
## A10 // Runtime: 21.6s, T/S = 41.3
## M1  // Runtime: TODO (not yet measured)
