In [None]:
%pip install torch==2.3.0 transformers==4.38.0 tqdm==4.66.4 accelerate

In [ ]:
%pip install huggingface_hub

In [ ]:
import os

from huggingface_hub import login

# Never hard-code an access token in a notebook: it leaks through version
# control and shared copies. The token is read from the HF_TOKEN environment
# variable instead; when the variable is unset or empty, `token=None` makes
# `login` ask for the token interactively.
login(token=os.environ.get("HF_TOKEN") or None)

In [ ]:
import json
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.cuda.amp import autocast, GradScaler
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from tqdm import tqdm
import accelerate

In [ ]:
class QwenDataset(Dataset):
    """Map-style dataset of prompt/completion pairs for causal-LM fine-tuning.

    Every non-blank line of the JSONL file at ``data_path`` is parsed as a JSON
    object that must provide the keys ``input_text`` (prompt) and
    ``output_text`` (completion to learn).

    A causal language model scores ``labels`` against the logits produced for
    ``input_ids`` (shifting by one position internally), so both tensors have
    to describe the *same* token sequence. Each example is therefore encoded
    as ``prompt + completion + EOS``; the labels repeat that sequence with the
    prompt and padding positions replaced by ``IGNORE_INDEX`` so that only the
    completion contributes to the loss.

    Args:
        data_path: Path of the JSONL file to read.
        tokenizer: Tokenizer used to encode the texts. If it has no pad token,
            its EOS token is registered as the pad token (the tokenizer object
            is modified in place).
        max_length: Fixed length of every returned tensor.
    """

    # Label value that the cross-entropy loss skips.
    IGNORE_INDEX = -100

    def __init__(self, data_path, tokenizer, max_length=256):
        if tokenizer.pad_token_id is None:
            tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
        self.tokenizer = tokenizer

        self.data = self.load_data(data_path)
        self.max_length = max_length

    def load_data(self, data_path):
        """Return the list of records parsed from the JSONL file ``data_path``.

        Blank lines (for instance a trailing newline at the end of the file)
        are skipped instead of raising ``json.JSONDecodeError``.
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Encode example ``idx``.

        Returns:
            A dict with the 1-D ``torch.long`` tensors ``input_ids``,
            ``attention_mask`` and ``labels``, each of length ``max_length``.
        """
        item = self.data[idx]

        prompt_ids = list(
            self.tokenizer(item['input_text'], add_special_tokens=True)['input_ids']
        )
        # No special tokens here: the completion continues the prompt, it does
        # not start a new sequence.
        target_ids = list(
            self.tokenizer(item['output_text'], add_special_tokens=False)['input_ids']
        )
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None:
            # Supervise the end of the completion so the model learns to stop.
            target_ids.append(eos_id)

        # When the pair does not fit, keep the completion and shorten the
        # prompt: an example without supervised tokens has no defined loss.
        target_ids = target_ids[:self.max_length]
        prompt_ids = prompt_ids[:self.max_length - len(target_ids)]

        input_ids = prompt_ids + target_ids
        labels = [self.IGNORE_INDEX] * len(prompt_ids) + target_ids
        attention_mask = [1] * len(input_ids)

        # Right-pad to the fixed length. Labels are built explicitly rather
        # than by masking `pad_token_id`, because the pad token may be the EOS
        # token and a real EOS must stay supervised.
        pad_len = self.max_length - len(input_ids)
        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
        attention_mask = attention_mask + [0] * pad_len
        labels = labels + [self.IGNORE_INDEX] * pad_len

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

In [ ]:
def create_dataloaders(data_path, tokenizer_name, batch_size, max_length=512):
    """Build the shuffled training DataLoader together with its tokenizer.

    Args:
        data_path: Path handed to ``QwenDataset``.
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        batch_size: Number of examples per batch.
        max_length: Sequence length forwarded to ``QwenDataset``.

    Returns:
        A ``(dataloader, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    # The dataset is built first: its constructor may register a pad token on
    # the tokenizer object that the collator shares.
    examples = QwenDataset(data_path, tok, max_length)
    loader = DataLoader(
        examples,
        collate_fn=DataCollatorWithPadding(tokenizer=tok, return_tensors="pt"),
        shuffle=True,
        pin_memory=True,
        batch_size=batch_size,
    )
    return loader, tok

In [ ]:
class PromptTuning(nn.Module):
    """Causal language model with a trainable soft prompt in front of the input.

    ``num_virtual_tokens`` learned embedding vectors are prepended to the token
    embeddings of every sequence before the backbone model is called.

    Args:
        model_name: Identifier passed to ``AutoModelForCausalLM.from_pretrained``
            and ``AutoTokenizer.from_pretrained``.
        num_virtual_tokens: Number of soft-prompt vectors.
        prompt_embedding_init: Optional tensor of shape
            ``(num_virtual_tokens, hidden_size)`` used as the initial soft
            prompt. It is copied, so later changes to the caller's tensor do
            not affect the module.

    Raises:
        ValueError: If ``prompt_embedding_init`` does not have the expected
            shape.
    """

    def __init__(self, model_name, num_virtual_tokens=10, prompt_embedding_init=None):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=None
        )
        self.model.gradient_checkpointing_enable()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.num_virtual_tokens = num_virtual_tokens
        self.prompt_embeddings = nn.Embedding(num_virtual_tokens, self.model.config.hidden_size)

        if prompt_embedding_init is not None:
            expected_shape = tuple(self.prompt_embeddings.weight.shape)
            if tuple(prompt_embedding_init.shape) != expected_shape:
                raise ValueError(
                    f"prompt_embedding_init must have shape {expected_shape}, "
                    f"got {tuple(prompt_embedding_init.shape)}"
                )
            # Clone so the parameter does not share storage with the caller's tensor.
            self.prompt_embeddings.weight = nn.Parameter(prompt_embedding_init.detach().clone())

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the backbone on ``[soft prompt; input tokens]``.

        Args:
            input_ids: Token ids, indexed as ``(batch, seq)``.
            attention_mask: Mask matching ``input_ids``.
            labels: Optional labels matching ``input_ids``; the soft-prompt
                positions are filled with ``-100`` so they are excluded from
                the loss.

        Returns:
            Whatever the wrapped model returns for the extended sequence.
        """
        batch_size = input_ids.shape[0]
        token_embeds = self.model.get_input_embeddings()(input_ids)

        prompt_index = torch.arange(self.num_virtual_tokens, device=input_ids.device)
        # The soft prompt is created in PyTorch's default dtype while the
        # backbone is loaded in bfloat16. Without the cast, `torch.cat` would
        # promote the whole input to the wider dtype and the backbone would
        # only run under autocast.
        prompt_embeds = self.prompt_embeddings(prompt_index).to(token_embeds.dtype)
        prompt_embeds = prompt_embeds.unsqueeze(0).expand(batch_size, -1, -1)

        inputs_embeds = torch.cat([prompt_embeds, token_embeds], dim=1)

        # Virtual tokens are always attended to.
        prompt_attention_mask = attention_mask.new_ones((batch_size, self.num_virtual_tokens))
        attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)

        if labels is not None:
            prompt_labels = labels.new_full((batch_size, self.num_virtual_tokens), -100)
            labels = torch.cat([prompt_labels, labels], dim=1)

        return self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)

    def save_prompt_embeddings(self, path):
        """Write the soft-prompt state dict to ``path``."""
        torch.save(self.prompt_embeddings.state_dict(), path)

    def load_prompt_embeddings(self, path):
        """Load a soft prompt previously written by ``save_prompt_embeddings``.

        Tensors are mapped onto the device the soft prompt currently lives on,
        so a prompt saved from a GPU run can be loaded on a CPU-only machine.
        ``weights_only=True`` restricts unpickling to tensor data.
        """
        state = torch.load(
            path,
            map_location=self.prompt_embeddings.weight.device,
            weights_only=True,
        )
        self.prompt_embeddings.load_state_dict(state)

In [ ]:
def train(model_name, data_path, num_virtual_tokens, lr, batch_size, num_epochs, output_dir, accumulation_steps=4):
    """Train the soft prompt of a ``PromptTuning`` model and save it.

    Only parameters whose name contains ``prompt_embeddings`` are optimized;
    every other parameter is frozen. After the last epoch the soft prompt is
    written to ``<output_dir>/prompt_embeddings.pt``.

    Args:
        model_name: Model identifier, also used to load the tokenizer.
        data_path: Path of the JSONL training file.
        num_virtual_tokens: Number of soft-prompt vectors.
        lr: Learning rate for AdamW.
        batch_size: Examples per batch.
        num_epochs: Number of passes over the data.
        output_dir: Directory receiving the saved prompt (created if missing).
        accumulation_steps: Number of batches whose gradients are accumulated
            before each optimizer step.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1 or the
            dataloader yields no batches.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    dataloader, _ = create_dataloaders(data_path, model_name, batch_size)
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail before loading the model; the average loss would otherwise
        # divide by zero.
        raise ValueError(f"No training batches could be built from {data_path}")

    # Batches after this index form a shorter, trailing accumulation window.
    last_full_step = num_batches - num_batches % accumulation_steps
    trailing_window = num_batches - last_full_step

    model = PromptTuning(model_name, num_virtual_tokens=num_virtual_tokens)
    for name, param in model.named_parameters():
        if "prompt_embeddings" not in name:
            param.requires_grad = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model before building the optimizer so the optimizer is created
    # from the parameters on their final device.
    model.to(device)
    optimizer = torch.optim.AdamW(model.prompt_embeddings.parameters(), lr=lr)

    os.makedirs(output_dir, exist_ok=True)

    model.train()
    for epoch in range(1, num_epochs + 1):
        total_loss = 0.0
        optimizer.zero_grad()
        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", leave=False), 1):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # `torch.autocast` with an explicit device type also works on CPU.
            # No GradScaler is used: loss scaling addresses float16 underflow
            # and is not needed for bfloat16.
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            # Divide by the size of the window this batch belongs to, so the
            # accumulated gradient is the mean over that window.
            window = accumulation_steps if step <= last_full_step else trailing_window
            (loss / window).backward()

            # Also step on the final batch; otherwise the gradients of a
            # trailing partial window would be thrown away.
            if step % accumulation_steps == 0 or step == num_batches:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch} completed — Avg Loss: {avg_loss:.4f}")
        if device.type == "cuda":
            torch.cuda.empty_cache()

    save_path = os.path.join(output_dir, "prompt_embeddings.pt")
    model.save_prompt_embeddings(save_path)
    print(f"Prompt embeddings saved to {save_path}")

In [ ]:
# Model and file locations.
model_name = "mistralai/Mistral-7B-v0.1"
data_path = "./prompt_tuning_dataset.jsonl"
output_dir = "./output"

# Training hyper-parameters.
num_virtual_tokens = 10
num_epochs = 3
batch_size = 4
lr = 1e-3

# Make sure the destination exists before training starts.
os.makedirs(output_dir, exist_ok=True)

train(
    model_name=model_name,
    data_path=data_path,
    num_virtual_tokens=num_virtual_tokens,
    lr=lr,
    batch_size=batch_size,
    num_epochs=num_epochs,
    output_dir=output_dir,
)