In [None]:
%pip install torch>=2.1.0 transformers>=4.34.0 peft>=0.5.0 accelerate>=0.20.0

In [ ]:
# huggingface_hub supplies the login() helper used to authenticate below.
%pip install huggingface_hub

In [ ]:
import os

from huggingface_hub import login

# Keep the access token out of the notebook. It is read from the HF_TOKEN
# environment variable; when that is unset or empty, None is passed so that
# login() asks for the token interactively instead of submitting "".
login(token=os.environ.get("HF_TOKEN") or None)

In [ ]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, PrefixTuningConfig, TaskType
from torch.utils.data import Dataset, DataLoader
import json

In [ ]:
class ChatDataset(Dataset):
    """Causal-LM dataset built from a JSONL file of prompt/completion pairs.

    Every non-blank line of ``data_path`` must be a JSON object with the
    string fields ``input_text`` and ``output_text``. Both are stripped of
    surrounding whitespace and concatenated into a single training text.
    Tokenization is deferred to ``__getitem__``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            truncation=True, return_tensors='pt')``; its result must expose
            ``input_ids`` and ``attention_mask`` tensors.
        data_path: Path of the UTF-8 encoded JSONL file.
        max_length: Truncation length handed to the tokenizer.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``input_text`` or ``output_text``.
    """

    def __init__(self, tokenizer, data_path, max_length=512):
        self.tokenizer = tokenizer
        self.data = []
        self.max_length = max_length

        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Blank lines (e.g. a trailing newline at end of file) are
                # not records; json.loads would raise on them.
                if not line.strip():
                    continue
                ex = json.loads(line)
                prompt = ex['input_text'].strip()
                completion = ex['output_text'].strip()
                # NOTE(review): prompt and completion are joined with no
                # separator and no EOS token is appended here — confirm the
                # dataset already encodes the turn boundary.
                self.data.append(prompt + completion)

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return its tensors.

        Returns:
            Dict with ``input_ids`` and ``attention_mask``; the leading
            dimension of the tokenizer output is squeezed away (presumably a
            batch dimension of size 1 — depends on the tokenizer).
        """
        text = self.data[idx]
        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': tokens.input_ids.squeeze(0),
            'attention_mask': tokens.attention_mask.squeeze(0),
        }

In [ ]:
def collate_fn(batch, pad_token_id=None):
    """Pad a list of tokenized examples into one causal-LM training batch.

    Args:
        batch: List of dicts, each holding 1-D ``input_ids`` and
            ``attention_mask`` tensors.
        pad_token_id: Id used to right-pad ``input_ids``. When omitted, the
            module-level ``tokenizer.pad_token_id`` is used.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels``, all padded
        to the longest sequence in the batch. ``labels`` equals ``input_ids``
        except at padded positions, which are set to -100 so that they do not
        contribute to the loss.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    input_ids = torch.nn.utils.rnn.pad_sequence(
        [b['input_ids'] for b in batch],
        batch_first=True,
        padding_value=pad_token_id,
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [b['attention_mask'] for b in batch],
        batch_first=True,
        padding_value=0,
    )

    # Mask by attention mask rather than by token id: the pad token may share
    # its id with EOS, and genuine EOS tokens must keep their labels.
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    return {'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels}

In [ ]:
# Run configuration for the prefix-tuning experiment.

# Hub id of the base model and tokenizer to load.
MODEL_NAME = "mistralai/Mistral-7B-v0.1"
# JSONL training file read by ChatDataset.
DATA_PATH = "./prefix_tuning_dataset.jsonl"
# Directory the model and tokenizer are saved to after training.
OUTPUT_DIR = "./mistral_prefix_tuned"
# Examples per DataLoader batch.
BATCH_SIZE = 4
# Number of passes over the training loader.
NUM_EPOCHS = 3
# Learning rate given to AdamW.
LR = 2e-5
# Number of virtual (prefix) tokens intended for the prefix-tuning config.
NUM_VIRTUAL_TOKENS = 50
# max_length passed to ChatDataset, i.e. the tokenizer truncation length.
MAX_LENGTH = 256

In [ ]:
# Load the base tokenizer and model. Weights are loaded in bfloat16 with no
# device map and then moved as a whole onto the single selected device.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map=None
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Batching pads sequences, which needs a pad token; reuse EOS when the
# tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Wrap the model with a prefix-tuning adapter. The prefix length is taken
# from NUM_VIRTUAL_TOKENS so the configuration cell stays the single source
# of truth. peft_type is not passed: PrefixTuningConfig sets it itself.
peft_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=NUM_VIRTUAL_TOKENS,
    encoder_hidden_size=model.config.hidden_size
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

In [ ]:
# Training data: the JSONL examples, served in shuffled, padded batches.
train_dataset = ChatDataset(
    tokenizer=tokenizer,
    data_path=DATA_PATH,
    max_length=MAX_LENGTH,
)
loader_options = dict(batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, **loader_options)

# AdamW over the model's parameters at the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)

In [ ]:
# Fine-tuning loop: one optimizer step per batch, with the step loss printed
# every 10 steps and the mean loss printed at the end of each epoch.
model.train()
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for step, batch in enumerate(train_loader):
        # Move every tensor of the batch onto the model's device.
        batch = {name: tensor.to(model.device) for name, tensor in batch.items()}

        optimizer.zero_grad()
        out = model(**batch)
        loss = out.loss
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the total and the log line.
        loss_value = loss.item()
        total_loss += loss_value
        if not step % 10:
            print(f"Epoch {epoch+1} Step {step} — loss {loss_value:.4f}")

    avg = total_loss / len(train_loader)
    print(f"=== Epoch {epoch+1} done — avg loss {avg:.4f} ===")

In [ ]:
# Persist the trained model and its tokenizer side by side in OUTPUT_DIR.
os.makedirs(OUTPUT_DIR, exist_ok=True)
for artifact in (model, tokenizer):
    artifact.save_pretrained(OUTPUT_DIR)