In [None]:
# Launch command (shell, not Python): runs one bash command under `srun -G1 --pty` that
# activates the team virtualenv, prints the hostname, and starts Jupyter on port 14321
# rooted at the current directory ($(pwd) is expanded before srun runs).
# NOTE(review): the server is started with an empty password, allow_origin='*' and
# ip='*' — confirm the port is only reachable through a trusted network / SSH tunnel.
srun -G1 --pty bash -c "source /data/ai_club/team_3_2024-25/team3-env-finetune/bin/activate; \
    hostname; \
    jupyter notebook \
        --ServerApp.root_dir=$(pwd) \
        --ServerApp.password='' \
        --ServerApp.open_browser=False \
        --ServerApp.allow_origin='*' \
        --ServerApp.allow_remote_access=True \
        --ServerApp.port=14321 \
        --ServerApp.ip='*'
"

In [1]:
import torch
import torch.nn as nn
from torch.distributions import Categorical
import transformers

# from transformers import Qwen2ForCausalLM, Qwen2Config
from transformers import AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig

# import json
# import gc

In [2]:
class IMDecoderLayer(nn.Module):
    """Wrap a decoder layer and nudge its output toward a set of allowed tokens.

    After the wrapped layer runs, a vocabulary-sized vector (positive on the ids in
    ``IMDecoderLayer.mask``, negative elsewhere) is projected through the embedding
    matrix and added, scaled by this layer's strength, to the hidden state at the
    last position of the last batch element.

    Class-level state shared by every wrapped layer:
        mask: list of allowed token ids; an empty list disables the nudge. Must be
            set (not None) before the first forward pass.
        vspace_to_emb: embedding weight matrix, captured from the first layer built.
        block_strength: one scalar ``nn.Parameter`` per wrapped layer, in creation
            order. These live in a plain list, so they are not registered as
            parameters of any module.
    """
    mask = None
    vspace_to_emb = None
    block_strength = []

    def __init__(self, original_layer, vspace_to_emb, config, block_idx):
        """
        Args:
            original_layer: the decoder layer to wrap.
            vspace_to_emb: embedding module; only its ``.weight`` is used.
            config: object exposing ``vocab_size``.
            block_idx: accepted for interface compatibility; the effective index is
                this layer's position in ``IMDecoderLayer.block_strength``.
        """
        super().__init__()
        self.original_layer = original_layer

        # Identity check: `== None` on a tensor goes through tensor comparison.
        if IMDecoderLayer.vspace_to_emb is None:
            IMDecoderLayer.vspace_to_emb = vspace_to_emb.weight

        emb = IMDecoderLayer.vspace_to_emb
        common_dtype = emb.dtype
        # Follow the embedding matrix instead of hard-coding 'cuda', so the matmul in
        # forward() always has both operands on the same device.
        device = emb.device

        self.block_idx = len(IMDecoderLayer.block_strength)
        IMDecoderLayer.block_strength.append(
            nn.Parameter(torch.tensor(1.0, dtype=common_dtype, device=device))
        )

        # Scratch buffer reused on every forward pass.
        self.vstate = torch.zeros(config.vocab_size, dtype=common_dtype, device=device)

    def forward(self, hidden_states, *args, **kwargs):
        """Run the wrapped layer, then add the mask-derived nudge in place.

        Returns a 1-tuple ``(hidden_states,)`` when the wrapped layer returned a
        tuple, and a bare tensor when the wrapped layer returned a bare tensor.
        """
        outputs = self.original_layer(hidden_states, *args, **kwargs)
        returned_tuple = isinstance(outputs, tuple)
        # Only unpack when the wrapped layer really returned a tuple; indexing a bare
        # tensor with [0] would silently select the first batch element instead.
        hidden_states = outputs[0] if returned_tuple else outputs

        mask = IMDecoderLayer.mask
        assert mask is not None, 'IMDecoderLayer.mask must be set before calling the model'
        if mask:
            # De-duplicate so repeated ids do not inflate the allowed count.
            allowed_ids = list(set(mask))
            n_allowed = len(allowed_ids)
            n_disallowed = self.vstate.shape[-1] - n_allowed

            # self.vstate *= 0
            # Guard the all-tokens-allowed case, which would divide by zero.
            self.vstate[:] = -1 / n_disallowed if n_disallowed > 0 else 0.0
            self.vstate[allowed_ids] = 1 / n_allowed

            nudge = (self.vstate @ IMDecoderLayer.vspace_to_emb) * IMDecoderLayer.block_strength[self.block_idx]
            # The hidden state may live on a different device than the embedding matrix.
            hidden_states[-1, -1, :] += nudge.to(hidden_states.device)

        return (hidden_states,) if returned_tuple else hidden_states
        
# Strength schedule c*i/(i+c): 0 for the first wrapped layer, growing toward c for later ones.
# NOTE: block_strength is empty right after the class statement (entries are appended only
# in IMDecoderLayer.__init__), so at this point in the cell the loop body never runs.
for i, s in enumerate(IMDecoderLayer.block_strength):
    c=3 # This is a hyperparameter
    s.data.fill_(c*i/(i+c))

In [3]:
# Checkpoint used for the config, the tokenizer and the model below.
MODEL_NAME = 'meta-llama/Llama-3.1-8B-Instruct'

config = AutoConfig.from_pretrained(MODEL_NAME)

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
# Batched tokenization with padding=True (see tokenize()) needs a pad token.
# NOTE(review): a new '[PAD]' token is added to the tokenizer, but no embedding resize
# is done in this cell — confirm the pad id is never fed to the model as real input.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# 8-bit loading through bitsandbytes.
bnb = BitsAndBytesConfig(
    load_in_8bit=True,
    # bnb_8bit_use_double_quant=True,
    # bnb_8bit_quant_type="nf8",
    # bnb_8bit_compute_dtype=torch.bfloat16,

    llm_int8_enable_fp32_cpu_offload=True
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    low_cpu_mem_usage=True,
    # attn_implementation="flash_attention_2",
    # torch_dtype=torch.bfloat16,
    quantization_config=bnb
)

# FREEZE existing model.
# Runs before the layers are wrapped; the IMDecoderLayer strengths are created later and
# are kept in a class-level list, so this loop does not touch them.
for param in model.parameters():
    param.requires_grad = False

# REPLACE transformer blocks with IM ones
# IMDecoderLayer keeps its embedding reference and per-layer strengths in class-level
# attributes. Reset them first so re-running this cell (which loads a fresh model) does
# not append a second set of strengths or keep pointing at the previous model's embeddings.
IMDecoderLayer.vspace_to_emb = None
IMDecoderLayer.block_strength = []
for i in range(len(model.model.layers)):
    model.model.layers[i] = IMDecoderLayer(model.model.layers[i], model.model.embed_tokens, config, i)

def tokenize(batch):
    """Tokenize `batch` with padding and return the resulting tensors moved to 'cuda'.

    The result is a plain dict keyed like the tokenizer output, so it can be splatted
    into a model call.
    """
    encoded = tokenizer(batch, return_tensors='pt', padding=True)
    on_gpu = {}
    for name, tensor in encoded.items():
        on_gpu[name] = tensor.to('cuda')
    return on_gpu

def tokof(s, check=True):
    """Return the token id(s) for `s`, encoded without special tokens.

    Args:
        s: text to encode.
        check: when True, `s` must encode to exactly one token and that single id is
            returned; when False, the full list of ids is returned unchecked.

    Raises:
        ValueError: with check=True, if `s` encodes to zero tokens or to more than one.
    """
    toks = tokenizer(s, add_special_tokens=False)['input_ids']
    if not check:
        return toks
    if len(toks) > 1:
        raise ValueError(f'This is more than one tok: {toks}')
    if not toks:
        # Previously surfaced as a bare IndexError from toks[0].
        raise ValueError(f'No tokens produced for {s!r}')
    return toks[0]

2025-04-12 23:41:40.824929: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-12 23:41:40.839840: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744515700.855909 1678236 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744515700.859779 1678236 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-12 23:41:40.872982: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
# prompt = f'''<|im_start|>system
# You are a Finnish language teacher. Give short, simple responses.<|im_end|>
# <|im_start|>user
# How to ask "do you have a dog?"<|im_end|>
# <|im_start|>assistant
# '''

# Baseline check: generate one token at a time from a hand-written chat prompt, with the
# IMDecoderLayer nudge switched off (empty mask).
# NOTE(review): the prompt text already starts with <|begin_of_text|> and tokenize() calls
# the tokenizer with its default special-token handling — check for a duplicated BOS token.
prompt = f'''<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a Finnish language teacher. Give short, simple responses. Do not teach pronunciation.<|eot_id|><|start_header_id|>user<|end_header_id|>

How to ask "do you have a dog?"<|eot_id|><|start_header_id|>assistant<|end_header_id|>

'''

print(prompt, end='')

# Per-layer strength schedule c*i/(i+c): 0 for the first layer, approaching c for deep ones.
for i, s in enumerate(IMDecoderLayer.block_strength):
    c=2 # This is a hyperparameter
    s.data.fill_(c*i/(i+c))

# An empty mask makes IMDecoderLayer.forward skip the nudge entirely.
IMDecoderLayer.mask = [] # [tokof('On'), tokof('ko')]

# Up to 50 steps; each step re-tokenizes the whole prompt string and asks for one new token.
for _ in range(50):
    tokens = tokenize(prompt)
    out = model.generate(
        **tokens,
        max_new_tokens=1,
        pad_token_id=tokof('[PAD]'),
        # temperature=0.2,
        # NOTE(review): do_sample is commented out below, so this temperature presumably
        # has no effect on decoding — confirm against the generate() documentation.
        temperature=0.000001,
        # do_sample=True,
        return_dict_in_generate=True,
        # output_hidden_states=True
        output_logits=True
    )

    # out[0][0][-1]: last id of the first row of the first output field
    # (presumably the generated sequences — confirm).
    tok = tokenizer.decode(out[0][0][-1])
    if tok == '<|eot_id|>':
        break
    # The decoded text is appended to the prompt string, which is re-tokenized next step.
    prompt += tok
    print(tok, end='')

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a Finnish language teacher. Give short, simple responses. Do not teach pronunciation.<|eot_id|><|start_header_id|>user<|end_header_id|>

How to ask "do you have a dog?"<|eot_id|><|start_header_id|>assistant<|end_header_id|>

You can ask "Onko sinulla koira?"

In [23]:
# vocab = [ # Finnish
#     '.', '!', ',', ':', '?',
#     'terve', 'hei', 'talo', 'vesi', 'ystävä', 'huomenta', 'velho', 'suomi', 'koira', 'nimi', 'nimeni', 'nimesi', 'nimensä', 'sinulla',
#     'ystäväni', 'ystäväsi', 'ystävänsä', 'vanha', 'hyvää', 'suomalainen', 'mukava', 'minä', 'minun', 'olen', 'olenko', 'sinä', 'sinun', 'olet',
#     'oletko', 'hän', 'hänen', 'on', 'onko', 'matti', 'aleksi', 'sami', 'kyllä', 'ei', 'mitä', 'mikä', 'kuka', 'rossi', 'lucas'
# ]

# Allowed vocabulary: punctuation plus lowercase words. Capitalised and space-prefixed
# variants are derived from this list further down in the cell.
# NOTE: "casa" is listed twice (Simple Nouns and Everyday Phrases); the later
# list(set(vocab)) removes the repeat from `vocab`, but `vocab_raw` keeps both entries.
vocab = [ # Italian
    '.', '!', ',', ':', '?',

    # Pronouns
    "io", "tu", "lui", "lei", "noi", "voi", "loro",
 
    # Common Verbs
    "essere", "avere", "fare", "andare", "mangiare", "bere", 
    "parlare", "volere", "potere", "dovere",
 
    # Simple Nouns
    "casa", "scuola", "cibo", "acqua", "amico", "amica", 
    "lavoro", "tempo", "giorno", "notte",
 
    # Adjectives
    "buono", "buona", "bello", "bella", "grande", "piccolo", 
    "piccola", "stanco", "stanca", "felice", "triste",
 
    # Adverbs & Connectors
    "oggi", "domani", "sempre", "mai", "molto", "poco", 
    "e", "ma", "perché",
 
    # Everyday Phrases
    "ciao", "come", "stai", "sto", "bene", "mi", "chiamo", 
    "ho", "fame", "sete", "vado", "a", "casa", "non", "capisco",

    # Names (these also appear in the system prompt)
    'rossi', 'lucas'
]

# vocab = [
#     '.', '!', ',', ':',
#     'hello', 'this', 'is', 'my', 'story', 'i', 'went', 'to', 'the', 'a', 'saw',
#     'car', 'store', 'park', 'and', 'wizard', 'saw', 'buy', 'buying', 'oranges', 'apples', 'stuff', 'good', 'wizards', 'am'
# ]

# Snapshot of the hand-written list, taken before any variants are added.
vocab_raw = list(vocab)

# Capitalised twin of every entry (first character upper-cased).
vocab.extend([v[0].upper() + v[1:] for v in vocab])
# Space-prefixed twin of every purely alphabetic entry; other entries are repeated as-is.
vocab.extend([(v if not v.isalpha() else ' ' + v) for v in vocab])
# vocab += [v+'.' for v in vocab]

# Collapse the repeats introduced above.
vocab = [*set(vocab)]

# --- BUILD DA TRIE ---
# Nested dicts keyed by token id; a `None` key marks a node where a vocab entry ends.

trie = {}

for v in vocab:
    toks = tokof(v, check=False)

    # Walk from the root, creating missing child nodes on the way down.
    curr_node = trie
    for tok in toks:
        # tok = tokenizer.decode(tok) # FOR VISUALIZING
        curr_node = curr_node.setdefault(tok, {})

    curr_node[None] = {}

def get_next_allowed(given, trie, wrap):
    """Return the keys that may follow the sequence `given` according to `trie`.

    `given` is replayed through the trie from the root. A token that does not
    continue the current entry either starts a new entry (when the current node is
    a valid end and the token is a root key) or resets the walk to the root.

    Args:
        given: sequence of keys already produced (same key type as the trie).
        trie: nested dicts; a `None` key marks a node where an entry may end.
        wrap: when True, reaching a possible end also allows every root key (a new
            entry may start), and the `None` marker is removed from the result.
            When False, the current node's keys are returned as-is, `None` included.

    Returns:
        List of allowed next keys, without duplicates.
    """
    node = trie
    for tok in given:
        if tok in node:
            node = node[tok]
        elif None in node and tok in trie:
            # Previous entry is complete and `tok` begins a new one.
            node = trie[tok]
        else:
            # raise Exception(f'Unexpected token {tok}')

            # NOTE: fix for invalid prior seq - just pretend we're starting a new word
            node = trie

    allowed = list(node.keys())

    if wrap and None in allowed and given:
        allowed += list(trie.keys())
        # allowed += [t for t in trie.keys() if t[0] == ' ' or not t.isalpha()]

    if wrap:
        # Drop only the end marker. A plain truthiness test would also discard token
        # id 0, which is a legitimate key. dict.fromkeys removes repeats introduced by
        # the wrap-around while keeping first-seen order.
        allowed = [t for t in dict.fromkeys(allowed) if t is not None]

    # if not given:
    #     allowed = [v for v in allowed if v[0] != ' ']

    return allowed
    
# print(
#     # llama
#     # get_next_allowed([], trie),
#     # get_next_allowed([' o'], trie),
#     # get_next_allowed([' o', 'len'], trie),
#     # get_next_allowed([' o', 'len', 'ko'], trie),
#     # get_next_allowed([' o', 'len', 'ko', ' hu'], trie),
#     # get_next_allowed([' o', 'len', 'ko', ' hu', 'oment'], trie),
#     # get_next_allowed([' o', 'len', 'ko', ' hu', 'oment', 'a'], trie),
#     # get_next_allowed([' on', 'ok', 'ko', ' O'], trie),

#     # eu model
#     get_next_allowed([], trie),
#     get_next_allowed(['olen'], trie),
#     get_next_allowed(['olen', 'ko', ' nim'], trie),
#     get_next_allowed(['olen', 'ko', ' nim', '\n'], trie),
#     sep='\n\n'
# )

In [24]:
# Generate a separate emoji trie.
# It is used only for logit restriction; it does not drive any hidden-state changes.

import emoji
# Every emoji known to the `emoji` package, plus a space-prefixed copy of each one.
emojis = list(emoji.EMOJI_DATA)
emojis = emojis + [' ' + e for e in emojis]

# Same layout as the word trie: nested dicts keyed by token id.
emoji_trie = {}

for em in emojis:
    em_toks = tokenizer.encode(em, add_special_tokens=False)
    curr_node = emoji_trie
    for tok in em_toks:
        curr_node = curr_node.setdefault(tok, {})
    curr_node[None] = {} # End of tree - an emoji has been generated by this point.

In [110]:
import gc
# Force a garbage-collection pass; as the last expression in the cell, the return value
# of gc.collect() is what the notebook displays.
gc.collect()

0

In [139]:
# Format as model on the separate process:
# ^ message format is sent as below and converted into tokens in job
# ^ messages are sent along with list of allowed words, no spaces, all lowercase.
# ^ response is streamed per token back to client

# The first entry is the system prompt. It is written as adjacent string literals, which
# Python joins into ONE string with nothing in between — so each sentence ends with '\n'
# to keep the sentences from running together in the prompt.
messages = [ # system, then alternate user, assistant, ...
    (
        'You are an Italian language teacher named Rossi who teaches a learner (Lucas) via simple conversation.\n'
        'Hold a varied, engaging conversation for the learner.\n'
        # 'Use a TON of emojis. At least one per sentence.\n'
        'Keep responses to single, simple sentences.\n'
        f"Your responses can only draw from this allowed vocab (but don't need to be lowercase): {vocab_raw}."
    ),
    'Ciao! Keep this engaging by asking me questions.',
    'Come stai oggi, Lucas?',
    'Bene!',

    # 'Hei Lucas, mitä sinun nimesi on? 🤔',
    # 'Minun nimi on Lucas!',
    # 'Hyvää, Lucas, olen Rossi, sinun suomalainen ystäväsi 👋! Oletko suomalainen? 🤔',
    # 'Ei. Olen Amerikkalainen'
]

# Encode the conversation as token ids: the system message first, then the remaining
# messages alternating user / assistant (even index -> user).
tokens = tokenizer.encode(f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{messages[0]}<|eot_id|>', add_special_tokens=False)
for i, m in enumerate(messages[1:]):
    role = 'user' if i%2==0 else 'assistant'
    tokens += tokenizer.encode(f'<|start_header_id|>{role}<|end_header_id|>\n\n{m}<|eot_id|>', add_special_tokens=False)
# Open the assistant turn to be generated. The header is followed by a blank line ('\n\n')
# exactly like every other header above; a single '\n' here made the final turn
# inconsistent with the rest of the prompt.
tokens += tokenizer.encode('<|start_header_id|>assistant<|end_header_id|>\n\n', add_special_tokens=False)

print(tokenizer.decode(tokens))

# Per-layer strength schedule c*i/(i+c): 0 for the first layer, approaching c for deep ones.
for i, s in enumerate(IMDecoderLayer.block_strength):
    c=2 # This is a hyperparameter - "mask strength"
    s.data.fill_(c*i/(i+c))

# Collect repeatedly (at most 10 passes) until a pass reports nothing collected, then
# release cached CUDA memory before generating.
for _ in range(10):
    if gc.collect() == 0:
        break
torch.cuda.empty_cache() 

def next_tok(use_mask):
    """Generate one token, append it to the global `tokens` list, and return its id.

    The allowed set is whatever the vocab trie permits after `tokens`, plus
    <|eot_id|>. Allowed ids get a +100 logit bonus and the arg-max is taken.

    Args:
        use_mask: when True the allowed ids are also published to
            IMDecoderLayer.mask so the wrapped layers nudge the hidden state;
            when False the mask is emptied and only the logit bonus applies.
    """
    eot_id = tokof('<|eot_id|>')
    allowed = get_next_allowed(tokens, trie, True) + [eot_id]

    IMDecoderLayer.mask = allowed if use_mask else []

    # Inference only: the IMDecoderLayer strengths are nn.Parameters that require grad,
    # so without no_grad() every masked forward pass would record an autograd graph and
    # hold on to activations for no benefit.
    with torch.no_grad():
        logits = model(torch.tensor([tokens]).to('cuda')).logits[0][-1]

        logits[allowed] += 100
        # Categorical(logits=logits).sample()
        tok_id = int(logits.argmax())

    tokens.append(tok_id)
    return tok_id

# Generate up to 50 tokens. If a step runs out of GPU memory while the mask is active,
# switch the mask off and keep going; any other error is re-raised unchanged.
use_mask = True
for _ in range(50):
    try:
        tok_id = next_tok(use_mask)
    except RuntimeError as err:
        # CUDA out-of-memory errors are RuntimeErrors whose message mentions
        # "out of memory". Catching only those keeps real bugs (and Ctrl-C, which a bare
        # `except:` would swallow) from being misreported as OOM.
        if 'out of memory' not in str(err).lower():
            raise
        if use_mask == False: raise Exception('LLM is OOM, but already not using mask') from err
        print('<<Temporarily stopping mask>>')
        use_mask = False
        torch.cuda.empty_cache()
        continue
    if tok_id == tokof('<|eot_id|>'): break
    print(tokenizer.decode(tok_id), end='')

# Placeholder class used only so the editor can fold the commented-out experiment below;
# its body is a single `pass` and nothing in it executes.
# The commented code is an earlier draft of the generation loop that buffers emoji tokens
# in `curr_emoji` before committing them to `tokens`.
class _: # comment block to fold
    pass
    # curr_emoji = []
    # for _ in range(50):
    #     if tokens[-1] == tokof('<|eot_id|>'):
    #         break

    #     allowed_emoji = get_next_allowed(tokens+curr_emoji, emoji_trie, False)

    #     if not curr_emoji:
    #         allowed = get_next_allowed(tokens, trie, True)
            
    #         # IMDecoderLayer.mask = allowed + allowed_emoji
    #         IMDecoderLayer.mask = []

    #     logits = model(torch.tensor([tokens+curr_emoji]).to('cuda')).logits[0][-1]

    #     if curr_emoji:
    #         # print(curr_emoji[-1], end=' ')
    #         if None not in allowed_emoji:
    #             logits[allowed_emoji] += 1000
    #         tok_id = int(logits.argmax())
    #         if tok_id not in allowed_emoji:
    #             curr_emoji += [tokof(' ')]
    #             tokens += curr_emoji
    #             print(tokenizer.decode(curr_emoji), end='')
    #             curr_emoji = []
    #         else:
    #             curr_emoji.append(tok_id)
    #         continue

    #     logits[allowed + allowed_emoji] += 100

    #     # Categorical(logits=logits).sample()
    #     tok_id = int(logits.argmax())

    #     if tok_id in allowed_emoji:
    #         curr_emoji = [tok_id]
    #     else:
    #         tokens.append(tok_id)
    #         print(tokenizer.decode(tok_id), end='')

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an Italian language teacher named Rossi who teaches a learner (Lucas) via simple conversation.Hold a varied, engaging conversation for the learner.Keep responses to single, simple sentences.Your responses can only draw from this allowed vocab (but don't need to be lowercase): ['.', '!', ',', ':', '?', 'io', 'tu', 'lui', 'lei', 'noi', 'voi', 'loro', 'essere', 'avere', 'fare', 'andare','mangiare', 'bere', 'parlare', 'volere', 'potere', 'dovere', 'casa','scuola', 'cibo', 'acqua', 'amico', 'amica', 'lavoro', 'tempo', 'giorno', 'notte', 'buono', 'buona', 'bello', 'bella', 'grande', 'piccolo', 'piccola','stanco','stanca', 'felice', 'triste', 'oggi', 'domani','sempre','mai','molto', 'poco', 'e','ma', 'perché', 'ciao', 'come','stai','sto', 'bene','mi', 'chiamo', 'ho', 'fame','sete', 'vado', 'a', 'casa', 'non', 'capisco', 'rossi', 'lucas'].<|eot_id|><|start_header_id|>user<|end_header_id|>

Ciao! Keep this engaging by asking m

In [379]:
# prompt = f'''<|im_start|>system
# You are an Italian language teacher named Rossi who teaches a learner Lucas via simple conversation. Respond with one short sentence at a time. Your responses can only draw from this allowed vocab (but don't need to be lowercase): {vocab_raw}. IMPORTANT: Hold a varied, engaging conversation for the learner. Get to know them.<|im_end|>
# <|im_start|>user
# Ciao, I'm a language learner!<|im_end|>
# <|im_start|>assistant
# Ciao Lucas! Come stai oggi?<|im_end|>
# <|im_start|>user
# Sto bene! E voi?<|im_end|>
# <|im_start|>assistant
# Sto bene, Lucas! E tu, come stai?<|im_end|>
# <|im_start|>user
# Bene<|im_end|>
# <|im_start|>assistant
# '''

# Emoji experiment: generate token by token from a string prompt, boosting the logits of
# tokens that the emoji trie allows.
# NOTE(review): this cell does not match the helpers as currently defined in this
# notebook; see the notes inside the loop before running it.
prompt = f'''<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an Italian language teacher named Rossi who teaches a learner (Lucas) via simple conversation.
Hold a varied, engaging conversation for the learner. Get to know them.
Use a TON of emojis. At least one per sentence.
Keep responses to single, simple sentences.
Your responses can only draw from this allowed vocab (but don't need to be lowercase): {vocab_raw}.<|eot_id|><|start_header_id|>user<|end_header_id|>

Ciao!<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Ciao Lucas! Come va oggi?<|eot_id|><|start_header_id|>user<|end_header_id|>

Benne. Tu? (use emoji)<|eot_id|><|start_header_id|>assistant<|end_header_id|>

'''

# Per-layer strength schedule c*i/(i+c): 0 for the first layer, approaching c for deep ones.
for i, s in enumerate(IMDecoderLayer.block_strength):
    c=2 # This is a hyperparameter
    s.data.fill_(c*i/(i+c))

print(prompt, end='')

# Tokens of an emoji that has been started but not yet written to the prompt.
curr_emoji = []

for i in range(70):
    # NOTE(review): this rebinds the global name `tokens` to a dict of tensors; another
    # cell uses `tokens` as a plain list of token ids.
    tokens = tokenize(prompt)

    # NOTE(review): `given` holds decoded strings, but both tries in this notebook are
    # keyed by token ids, and get_next_allowed is defined with three required
    # parameters (given, trie, wrap) — the two-argument calls below raise TypeError.
    given = [tokenizer.decode(t) for t in tokens['input_ids'][0]]
    allowed = get_next_allowed(given, trie)
    # NOTE(review): `allowed` is not read again in this cell (the mask assignment that
    # used it is commented out below).
    allowed = [tokof(t) for t in allowed] + [tokof('<|eot_id|>')]

    allowed_emoji = get_next_allowed(given, emoji_trie)

    # IMDecoderLayer.mask = allowed
    IMDecoderLayer.mask = []

    out = model.generate(
        **tokens,
        max_new_tokens=1,
        pad_token_id=tokenizer.eos_token_id,
        # temperature=0.2,
        temperature=0.00001,
        # do_sample=True,
        return_dict_in_generate=True,
        # output_hidden_states=True
        output_logits=True
    )

    # Logits of the single generated step for the first sequence, with a flat bonus on
    # the emoji-allowed entries.
    logits = out.logits[0][0]
    logits[allowed_emoji] += 100#/(i+1)**2
    # tok_id = Categorical(logits=logits).sample() # temp = 1
    tok_id = logits.argmax() # temp = 0
    tok = tokenizer.decode(tok_id)

    # Debug output: the token id generate() itself picked (before the bonus above).
    print(out.sequences[0][-1])

    # NOTE(review): `tok` is decoded text while `allowed_emoji` comes from an id-keyed
    # trie, so this membership test compares different kinds of values.
    if tok in allowed_emoji:
        curr_emoji += [tok]

    # NOTE(review): once curr_emoji is non-empty it is never cleared, so the prompt stops
    # growing and the remaining iterations repeat the same step.
    if not curr_emoji:
        if tok == '<|eot_id|>':
            break

        prompt += tok
        print(tok, end='')

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an Italian language teacher named Rossi who teaches a learner (Lucas) via simple conversation.
Hold a varied, engaging conversation for the learner. Get to know them.
Use a TON of emojis. At least one per sentence.
Keep responses to single, simple sentences.
Your responses can only draw from this allowed vocab (but don't need to be lowercase): ['.', '!', ',', ':', '?', 'io', 'tu', 'lui', 'lei', 'noi', 'voi', 'loro', 'essere', 'avere', 'fare', 'andare', 'mangiare', 'bere', 'parlare', 'volere', 'potere', 'dovere', 'casa', 'scuola', 'cibo', 'acqua', 'amico', 'amica', 'lavoro', 'tempo', 'giorno', 'notte', 'buono', 'buona', 'bello', 'bella', 'grande', 'piccolo', 'piccola', 'stanco', 'stanca', 'felice', 'triste', 'oggi', 'domani', 'sempre', 'mai', 'molto', 'poco', 'e', 'ma', 'perché', 'ciao', 'come', 'stai', 'sto', 'bene', 'mi', 'chiamo', 'ho', 'fame', 'sete', 'vado', 'a', 'casa', 'non', 'capisco', 'rossi', 'lucas'].<|eot_i

KeyboardInterrupt: 