In [1]:
# Hugging Face access token, needed because the Meta Llama 2 weights are gated behind licence acceptance
from dotenv import load_dotenv
import os

# Let python-dotenv populate the process environment, then look the token up
# there; HF_TOKEN ends up as None when the variable is not set.
load_dotenv()
HF_TOKEN = os.environ.get('HF_TOKEN')

# Hub identifier of the chat checkpoint used throughout the notebook.
MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"

In [2]:

import torch
import lightning as L
from torch import optim
from transformers import AutoModelForCausalLM, AutoTokenizer, Conversation, LlamaForCausalLM, LlamaTokenizerFast

class LitLlamaChat(L.LightningModule):
    """Lightning module pairing a Llama causal LM with its tokenizer.

    Provides the training/validation hooks that log the model's own loss,
    plus convenience helpers for next-token logits and chat-style generation.

    Token-id batches are given as nested Python lists and converted with
    ``torch.tensor``, so every sequence in a batch must have the same length
    (ragged lists are rejected by ``torch.tensor``).
    """

    def __init__(self, llm: LlamaForCausalLM, tokenizer: LlamaTokenizerFast):
        """Store the wrapped language model and tokenizer.

        Args:
            llm: The causal language model to train / query.
            tokenizer: Tokenizer used to render and decode conversations.
        """
        super().__init__()
        self.llm, self.tokenizer = llm, tokenizer

    def _as_input_ids(self, batch: list[list[int]]) -> torch.Tensor:
        """Convert nested token-id lists into a long tensor on the model's device."""
        # Creating the tensor directly on the LLM's device avoids a
        # CPU-vs-GPU mismatch once the model has been moved off the CPU.
        return torch.tensor(batch, dtype=torch.long, device=self.llm.device)

    def training_step(self, batch, batch_idx):
        """Run the LLM on ``batch``, log its loss as ``train_loss`` and return it.

        ``batch`` is unpacked as keyword arguments to the model.
        NOTE(review): presumably it contains ``labels`` so that ``outputs.loss``
        is populated -- confirm against the dataloader.
        """
        outputs = self.llm(**batch)
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters (lr=5e-5)."""
        return optim.AdamW(self.parameters(), lr=5e-5)

    def validation_step(self, batch, batch_idx):
        """Run the LLM on ``batch`` and log its loss as ``val_loss``."""
        outputs = self.llm(**batch)
        self.log("val_loss", outputs.loss)

    def generate(self, batch: list[list[int]], *args, **kwargs):
        """Generate continuations for a batch of equal-length token-id lists.

        Extra positional and keyword arguments are forwarded unchanged to the
        wrapped model's ``generate``.

        NOTE(review): the original comment here assumed greedy (argmax)
        decoding by default; the effective default comes from the checkpoint's
        generation config -- confirm for this model.
        """
        return self.llm.generate(self._as_input_ids(batch), *args, **kwargs)

    def forward(self, batch: list[list[int]]) -> torch.Tensor:
        """Return next-token logits for a batch of equal-length token-id lists.

        Inference runs in eval mode without gradient tracking; the module's
        previous train/eval mode is restored afterwards so that calling this
        in the middle of training does not silently leave the model in eval
        mode.

        Returns:
            Logits at the last sequence position, shape (B, V).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Call the module (not .forward) so registered hooks still run.
                out = self.llm(input_ids=self._as_input_ids(batch))
        finally:
            if was_training:
                self.train()
        next_token_logits = out.logits[:, -1]    # (B, V): last position only
        return next_token_logits

    def generate_from_conversations(self, batch: list[Conversation], *args, **kwargs):
        """Tokenize conversations with the chat template, generate, and decode.

        Extra arguments are forwarded to :meth:`generate`. The decoded strings
        are produced by ``batch_decode`` on the full generated sequences, so
        they include the prompt as well as the newly generated text.
        """
        toks_in = [self.tokenizer.apply_chat_template(conv) for conv in batch]
        toks_out = self.generate(toks_in, *args, **kwargs)
        return self.tokenizer.batch_decode(toks_out)
    
# The repository is gated, so both the weights and the tokenizer files need
# the access token. `token` replaces the deprecated `use_auth_token` keyword.
model = LitLlamaChat(
    AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN),
    AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
)
        

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:18<00:00,  9.34s/it]


In [3]:
# Single-conversation batch: a system instruction followed by one user turn.
messages = [
    {"role": "system", "content": "Answer the following questions:"},
    {"role": "user", "content": "Who is God?"},
]
batch = [Conversation(messages)]

# Tokenize each conversation with the chat template, then decode the ids back
# to text to inspect the exact prompt the model will receive.
toks_in = list(map(model.tokenizer.apply_chat_template, batch))
print(model.tokenizer.batch_decode(toks_in))

['<s> [INST] <<SYS>>\nAnswer the following questions:\n<</SYS>>\n\nWho is God? [/INST]']


In [4]:
# Generate at most 10 new tokens per conversation and show the first result.
replies = model.generate_from_conversations(batch, max_new_tokens=10)
print(replies[0])

<s> [INST] <<SYS>>
Answer the following questions:
<</SYS>>

Who is God? [/INST]  I'm just an AI, I
