In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import json
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch.optim import AdamW



In [None]:
import subprocess

def get_gpu_memory_info():
    """Print the name and total/used/free memory of every GPU reported by `nvidia-smi`.

    The figures come from
    `nvidia-smi --query-gpu=name,memory.total,memory.used,memory.free --format=csv,nounits,noheader`
    and are divided by 1024 before being shown as GB (the unit-less output is
    assumed to be MiB).

    Nothing is returned. If `nvidia-smi` is not installed or exits with an
    error, a message is printed instead of raising, so the notebook keeps
    running on CPU-only machines.
    """
    try:
        result = subprocess.check_output(
            ['nvidia-smi', '--query-gpu=name,memory.total,memory.used,memory.free', '--format=csv,nounits,noheader'],
            encoding='utf-8'
        )
    except FileNotFoundError:
        # No NVIDIA driver/tooling on this machine (e.g. a CPU-only runtime).
        print("nvidia-smi not found; no GPU memory information available.")
        return
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")
        return

    # One CSV line per GPU; ignore empty lines in the output.
    gpu_lines = [line.strip() for line in result.splitlines() if line.strip()]
    for idx, line in enumerate(gpu_lines):
        name, total, used, free = line.split(', ')
        print(f"GPU {idx}: {name}")
        print(f"  Total memory: {int(total) / 1024:.2f} GB")
        print(f"  Used memory: {int(used) / 1024:.2f} GB")
        print(f"  Free memory: {int(free) / 1024:.2f} GB")

# Print the current GPU memory figures (this cell runs before the model-loading cell).
get_gpu_memory_info()


In [None]:
# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Hugging Face model id of the checkpoint loaded below; the resulting `model`
# is later passed to `train_student` as the teacher.
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Load the tokenizer
# NOTE(review): `padding` is normally a per-call tokenization argument; confirm
# that passing it to `from_pretrained` has any effect.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding = False)

# Configure quantization: 4-bit loading, every other option left at its default.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
)

# Load the model with quantization
# NOTE(review): no explicit device placement is requested here; presumably the
# quantized loader places the weights on the GPU -- confirm.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config
)
# Reuse the EOS token as the padding token; `collate_fn` pads `input_ids`
# with `tokenizer.eos_token_id` accordingly.
tokenizer.pad_token = tokenizer.eos_token



In [None]:
class InstructionDataset(Dataset):
    """Dataset of instruction instances tokenized on the fly.

    Args:
        instances: sequence of dicts with string fields "input" and "output".
        tokenizer: callable that accepts a string plus `padding`, `truncation`
            and `return_tensors` keyword arguments and returns an object
            exposing `input_ids` and `attention_mask` tensors with a leading
            batch dimension of 1.

    Each item is a dict with:
        "input_ids":      1-D tensor of prompt token ids
        "attention_mask": 1-D tensor matching "input_ids"
        "labels":         1-D tensor of token ids of the expected output text
        "original_text":  the untokenized prompt string

    No padding or truncation happens here; batching code is expected to pad.
    """

    def __init__(self, instances, tokenizer):
        self.instances = instances
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.instances)

    def _encode(self, text):
        """Tokenize one string without padding or truncation."""
        return self.tokenizer(
            text,
            padding=False,
            truncation=False,
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        instance = self.instances[idx]
        input_text = instance["input"]
        output_text = instance["output"]

        input_encoding = self._encode(input_text)
        output_encoding = self._encode(output_text)

        # Drop only the batch dimension. A bare `.squeeze()` would also collapse
        # a one-token sequence into a 0-d tensor, which cannot be padded later.
        return {
            "input_ids": input_encoding.input_ids.squeeze(0),
            "attention_mask": input_encoding.attention_mask.squeeze(0),
            "labels": output_encoding.input_ids.squeeze(0),
            "original_text": input_text
        }

def prepare_data(data_dir, tasks_dir, number_of_tasks):
    """Load instruction instances for the first `number_of_tasks` tasks of a split.

    Args:
        data_dir: path of a text file listing one task name per line
            (despite the name, this is a file, not a directory).
        tasks_dir: directory containing one `<task_name>.json` file per task.
        number_of_tasks: how many task names, counted from the top of the
            list, to load. Blank lines in the list are ignored and do not count.

    Returns:
        A list of dicts `{"input": ..., "output": ...}` where "input" is the
        task's first "Definition" entry, a newline, the instance input and a
        trailing newline, and "output" is the instance's first listed output.
    """
    def load_task_names(file_path):
        with open(file_path, 'r', encoding='utf-8') as file:
            # Skip blank lines so they are never mistaken for a task name
            # (which would try to open "<tasks_dir>/.json").
            return [line.strip() for line in file if line.strip()]

    def load_task_data(task_name):
        with open(f"{tasks_dir}/{task_name}.json", 'r', encoding='utf-8') as file:
            return json.load(file)

    def extract_instances(data):
        task_instances = data.get("Instances", [])
        if not task_instances:
            return []
        # The definition is the same for every instance of a task; look it up once.
        definition = data.get("Definition")[0]
        return [{
            "input": definition + "\n" + instance["input"] + "\n",
            "output": instance["output"][0]
        } for instance in task_instances]

    all_instances = []
    for task in load_task_names(data_dir)[:number_of_tasks]:
        all_instances.extend(extract_instances(load_task_data(task)))

    return all_instances

# Text files listing the task names of each split, one per line
# (these are file paths even though the variables are named `*_dir`).
train_data_dir = "./splits/default/train_tasks.txt"
test_data_dir = "./splits/default/test_tasks.txt"
# Directory holding one `<task_name>.json` file per task.
tasks_dir = "./tasks/"
# Load the first 3 tasks of the train list and the first task of the test list.
instances = prepare_data(train_data_dir, tasks_dir, 3)
test_instances = prepare_data(test_data_dir, tasks_dir, 1)
# Wrap the instance lists in datasets; the DataLoaders are created after `collate_fn`.
dataset = InstructionDataset(instances, tokenizer)

test_dataset = InstructionDataset(test_instances, tokenizer)

def collate_fn(batch):
    """Pad a list of dataset items to a common length and stack them.

    Every item must provide 1-D tensors under 'input_ids', 'attention_mask'
    and 'labels'. Sequences are padded at the end, up to the longest one in the
    batch, using the tokenizer's EOS id for 'input_ids', 0 for
    'attention_mask' and -100 for 'labels'.

    Returns a dict with the same three keys holding batch-first tensors.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    # (key, fill value) pairs, in the order the keys appear in the result.
    fill_values = (
        ('input_ids', tokenizer.eos_token_id),
        ('attention_mask', 0),
        ('labels', -100),
    )

    padded = {}
    for key, fill in fill_values:
        sequences = [example[key] for example in batch]
        padded[key] = pad(sequences, batch_first=True, padding_value=fill)
    return padded

# Training batches are reshuffled every epoch.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
# Evaluation does not benefit from shuffling; keeping the dataset order makes
# the predictions printed during evaluation reproducible between runs.
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)


In [None]:
class DecoderOnlyModel(nn.Module):
    """Small causal language model built from `nn.TransformerDecoder` layers.

    Token and learned position embeddings are summed, passed through the
    decoder stack under a causal mask, layer-normalized and projected to
    vocabulary logits. The decoder layers are given an empty (zero-length)
    `memory`, so only the masked self-attention sees real input.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        embed_dim: model width.
        num_heads: attention heads per layer.
        num_layers: number of decoder layers.
        hidden_dim: feed-forward width inside each layer.
        dropout: dropout probability for the embeddings and the layers.
        max_positions: size of the position-embedding table, i.e. the longest
            sequence the model accepts (default 512, the previous fixed value).
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, hidden_dim, dropout=0.1, max_positions=512):
        super(DecoderOnlyModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_positions, embed_dim)
        decoder_layer = nn.TransformerDecoderLayer(embed_dim, num_heads, hidden_dim, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, input_ids):
        """Return next-token logits of shape (batch, seq_len, vocab_size).

        Args:
            input_ids: integer tensor of shape (batch, seq_len).

        Raises:
            ValueError: if seq_len exceeds the position-embedding table.
        """
        seq_len = input_ids.size(1)
        max_positions = self.position_embedding.num_embeddings
        if seq_len > max_positions:
            # Fail with a readable message instead of an opaque embedding index error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum of {max_positions} positions supported by this model"
            )

        # Shape (1, seq_len); the position embeddings broadcast over the batch.
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        causal_mask = torch.full((seq_len, seq_len), float('-inf'), device=x.device).triu(diagonal=1)

        # The decoder layers are created without batch_first, so they take (seq, batch, dim).
        x = x.transpose(0, 1)
        memory = torch.zeros((0, x.size(1), x.size(2)), device=x.device)
        x = self.transformer_decoder(x, memory, tgt_mask=causal_mask).transpose(0, 1)

        x = self.layer_norm(x)
        logits = self.output_layer(x)
        return logits

    def generate(self, input_ids, max_length, eos_token_id=None):
        """Greedy decoding up to `max_length` total tokens (prompt included).

        Args:
            input_ids: integer tensor of shape (batch, prompt_len).
            max_length: maximum length of the returned sequences.
            eos_token_id: id that ends a sequence; defaults to the notebook's
                global `tokenizer.eos_token_id`.

        Returns:
            Tensor of shape (batch, <= max_length) holding the prompt followed
            by the generated ids. A sequence that has produced EOS keeps
            receiving EOS while the rest of the batch is still generating;
            decoding stops early once every sequence has produced EOS.
        """
        if eos_token_id is None:
            eos_token_id = tokenizer.eos_token_id

        generated_ids = input_ids
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)

        # Decoding only needs the argmax, so skip autograd bookkeeping.
        with torch.no_grad():
            for _ in range(max_length - input_ids.size(1)):
                outputs = self.forward(generated_ids)
                next_token_ids = outputs[:, -1, :].argmax(dim=-1)
                # Sequences that already ended only ever receive EOS.
                next_token_ids = next_token_ids.masked_fill(finished, eos_token_id)
                generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

                finished = finished | (next_token_ids == eos_token_id)
                if finished.all():
                    break
        return generated_ids

# Student model size. The vocabulary size is taken from the tokenizer so the
# student's logits line up with the token ids it is fed.
vocab_size = len(tokenizer)
embed_dim = 1144  # model width; 1144 / 8 heads = 143 dimensions per head
num_heads = 8
num_layers = 8
hidden_dim = 2288  # feed-forward width (2 * embed_dim)
dropout = 0.1
# Training hyperparameters.
# NOTE(review): in this notebook the DataLoaders hard-code batch_size=8 and the
# `train_student(...)` call passes epochs=5 without learning_rate/weight_decay,
# so the four values below do not appear to be used -- confirm which are intended.
learning_rate = 2e-5
weight_decay = 0.01
batch_size = 8
epochs = 6


In [None]:
def combined_loss(logits, hard_labels, soft_logits, alpha=0.5, temperature=1.0):
    """Distillation loss: weighted sum of a soft (KL) and a hard (cross-entropy) term.

    Args:
        logits: student logits, shape (..., vocab).
        hard_labels: integer class targets, one per student position
            (any shape that flattens to the number of positions).
        soft_logits: teacher logits, same shape as `logits`.
        alpha: weight of the KL term; the cross-entropy term gets `1 - alpha`.
        temperature: softening temperature applied to both sets of logits for
            the KL term only.

    Returns:
        Scalar tensor `alpha * KL + (1 - alpha) * CE`.

    Both terms are averaged per position: the logits are flattened to
    (positions, vocab) first, so the KL term does not grow with sequence length
    while the cross-entropy term stays a per-token mean. The KL term is
    multiplied by `temperature ** 2`. The cross-entropy term uses the
    unscaled student logits.
    """
    student_flat = logits.reshape(-1, logits.size(-1))
    teacher_flat = soft_logits.reshape(-1, soft_logits.size(-1))

    kl_loss = nn.functional.kl_div(
        torch.log_softmax(student_flat / temperature, dim=-1),
        torch.softmax(teacher_flat / temperature, dim=-1),
        reduction='batchmean'
    ) * (temperature ** 2)

    # Hard-label term on the raw logits: the temperature must not leak into it.
    ce_loss = nn.functional.cross_entropy(student_flat, hard_labels.reshape(-1))

    return alpha * kl_loss + (1 - alpha) * ce_loss

def evaluate_model(model, dataloader, tokenizer):
    """Generate continuations for every batch and score them with corpus-level BLEU.

    Args:
        model: model exposing `eval()` and a Hugging Face style
            `generate(input_ids, max_new_tokens=..., num_beams=..., attention_mask=...)`.
        dataloader: iterable of batches with 'input_ids', 'attention_mask' and
            'labels' tensors (labels padded with -100).
        tokenizer: tokenizer providing `decode`, `vocab_size` and `all_special_ids`.

    Returns:
        The smoothed (method1) corpus BLEU score of the generated continuations
        against the decoded labels, using whitespace tokenization.

    Only the tokens after the prompt are scored; this assumes `generate`
    returns each prompt followed by its continuation -- confirm for models
    that behave differently.

    NOTE(review): `collate_fn` in this notebook pads prompts on the right,
    while batched generation with decoder-only models is usually done with
    left padding; shorter prompts in a batch may therefore generate poorly.
    """
    model.eval()
    special_ids = set(tokenizer.all_special_ids)
    vocab_limit = tokenizer.vocab_size

    def decode_ids(ids):
        # Drop label padding (-100), out-of-range ids and special tokens.
        kept = [tok for tok in ids.tolist() if 0 <= tok < vocab_limit and tok not in special_ids]
        return tokenizer.decode(kept, skip_special_tokens=True)

    all_references = []
    all_candidates = []

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i % 10 == 0:
                print(i)

            input_ids = batch['input_ids'].to(device)
            # The mask must live on the same device as the ids it describes.
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']

            generated_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False, early_stopping = True, num_beams = 3, attention_mask = attention_mask)

            prompt_len = input_ids.size(1)
            for j in range(labels.size(0)):
                label_text = decode_ids(labels[j])
                # Score only what the model produced, not the echoed prompt.
                pred_text = decode_ids(generated_ids[j][prompt_len:])

                print("pred:")
                print(pred_text)
                print("end pred:")

                # corpus_bleu expects, per hypothesis, a LIST of reference token lists.
                all_references.append([label_text.split()])
                all_candidates.append(pred_text.split())

    smoothing_function = SmoothingFunction().method1
    bleu_score = corpus_bleu(all_references, all_candidates, smoothing_function=smoothing_function)

    return bleu_score

def load_model(model_class, tokenizer_class, model_path, config_class, pretrained_model_name):
    """Load a saved model and a tokenizer through their `from_pretrained` constructors.

    Args:
        model_class: class whose `from_pretrained(model_path, config=...)` builds the model.
        tokenizer_class: class whose `from_pretrained(pretrained_model_name)` builds the tokenizer.
        model_path: checkpoint location passed to both `config_class.from_pretrained`
            and `model_class.from_pretrained`.
        config_class: class whose `from_pretrained(model_path)` builds the config.
        pretrained_model_name: identifier the tokenizer is loaded from.

    Returns:
        `(model, tokenizer)`.

    The weights come from `model_class.from_pretrained`. The earlier extra
    `model.load_state_dict(torch.load(model_path))` step was removed: it passed
    the same path to `torch.load`, which expects a single serialized file and
    therefore failed whenever `model_path` was a checkpoint directory.
    """
    config = config_class.from_pretrained(model_path)
    model = model_class.from_pretrained(model_path, config=config)
    tokenizer = tokenizer_class.from_pretrained(pretrained_model_name)
    return model, tokenizer

def train_student(student, dataloader, teacher_model, model_name, epochs=2, alpha=0.5, temperature=1.0, learning_rate=5e-5, weight_decay=0.01):
    """Distill `teacher_model` into `student` and save the student's weights.

    For every batch the teacher's logits serve as soft targets and their
    argmax as hard labels; the two are mixed by `combined_loss`. Only
    positions where `attention_mask` is non-zero contribute, so padding does
    not influence the loss.

    Args:
        student: module called as `student(input_ids)` returning logits of
            shape (batch, seq, vocab); it is put in train mode and moved to `device`.
        dataloader: iterable of batches with 'input_ids' and 'attention_mask'.
        teacher_model: module called as `teacher_model(input_ids, attention_mask)`
            whose result has a `.logits` attribute shaped like the student's
            logits; it is put in eval mode and run without gradients.
        model_name: the weights are written to `f"{model_name}.pt"`.
        epochs: number of passes over `dataloader`.
        alpha, temperature: forwarded to `combined_loss`.
        learning_rate, weight_decay: AdamW settings.

    Mixed precision (autocast + gradient scaling) is enabled only when the
    notebook's `device` is a CUDA device.
    """
    student.train()
    student.to(device)
    teacher_model.eval()

    optimizer = AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
    use_amp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        epoch_iterator = tqdm(dataloader, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch')
        for i, batch in enumerate(epoch_iterator):
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            token_mask = attention_mask.bool()

            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = student(input_ids)

                with torch.no_grad():
                    teacher_logits = teacher_model(input_ids, attention_mask).logits

                # Keep real tokens only: (num_tokens, vocab) for both models.
                student_token_logits = logits[token_mask]
                teacher_token_logits = teacher_logits[token_mask]
                hard_labels = teacher_token_logits.argmax(dim=-1)

                total_batch_loss = combined_loss(student_token_logits, hard_labels, teacher_token_logits, alpha, temperature)

            scaler.scale(total_batch_loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += total_batch_loss.item()
            num_batches = i + 1

            epoch_iterator.set_postfix(loss=total_loss / num_batches)
            torch.cuda.empty_cache()

        # Average over the batches actually seen; also safe for an empty loader.
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {total_loss / max(num_batches, 1)}')

    torch.save(student.state_dict(), f"{model_name}.pt")
    print(f"Model saved as {model_name}.pt")


In [None]:
# Build the student from the size settings defined above and move it to `device`.
student_model = DecoderOnlyModel(vocab_size, embed_dim, num_heads, num_layers, hidden_dim, dropout).to(device)
# The teacher is the quantized `model` loaded at the top of the notebook; it is
# handed to `train_student` in the next cell.

In [None]:
# Distill the quantized teacher `model` into `student_model`; weights are saved to "student model.pt".
# NOTE(review): epochs=5 is passed literally, and learning_rate / weight_decay fall back to the
# function defaults (5e-5 / 0.01) rather than the `epochs`, `learning_rate` and `weight_decay`
# variables defined earlier -- confirm which values are intended.
train_student(student_model, dataloader, model, model_name="student model" , epochs=5, alpha=0.5)

In [None]:
# BLEU on the test split.
# NOTE(review): this evaluates `model` (the teacher), not `student_model` -- confirm that is intended.
print(evaluate_model(model, test_dataloader, tokenizer))

In [None]:
# Hand-written example for a quick qualitative check of the model.
input_text = "Comment: Me too. He might have gotten away, but it wouldn't be with my cat or all his body parts. I would have scratched his eyes out and pummeled his testicles into mush!"
definition = 'Task: Classify the following comment as sexual-explicit or non-sexual-explicit. A comment is considered sexual-explicit if it explicitly portrays sexual matters. give me only the classafaction, without an explanation.'
# NOTE(review): `system` is defined but not included in the prompt below.
system = 'System: This is a classification task. You need to classify the comment based on its content.'

# Combine the task definition and the comment into a prompt
prompt = f"{definition}\n{input_text}"

# Encode the single prompt without padding and move the tensors to `device`
encoding = tokenizer(prompt, return_tensors='pt', padding=False).to(device)

# These overwrite any earlier module-level `input_ids` / `attention_mask`.
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

# No-op assignment: the token ids are left unchanged here.
input_ids = input_ids

# Function to stop at first newline
def generate_until_newline(model, tokenizer, input_ids, max_new_tokens=50):
    """Greedily generate one line of text after the prompt.

    Tokens are produced one at a time with `model.generate(..., max_new_tokens=1)`
    so that decoding can stop as soon as a line is complete.

    Args:
        model: model exposing a Hugging Face style `generate`.
        tokenizer: tokenizer providing `decode`, `eos_token_id` and `pad_token_id`.
        input_ids: prompt token ids of shape (1, prompt_len).
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The generated text up to (not including) the first newline that
        follows some non-whitespace output; leading whitespace is stripped.
        Generation also stops at the EOS token or after `max_new_tokens`
        tokens. The prompt itself is not part of the result.
    """
    # Token ids must stay integer; never round-trip them through float.
    input_ids = input_ids.to(torch.int64)
    new_token_ids = []
    generated_text = ""

    for _ in range(max_new_tokens):
        outputs = model.generate(
            input_ids,
            max_new_tokens=1,  # Generate one token at a time
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,  # Greedy, deterministic decoding
            pad_token_id=tokenizer.pad_token_id,  # Ensure padding is handled correctly
        )
        new_token = outputs[:, -1:]
        token_id = int(new_token[0, 0])
        if token_id == tokenizer.eos_token_id:
            break

        new_token_ids.append(token_id)
        # Decode the whole continuation each time: decoding tokens one by one
        # can mangle spacing and characters that span several tokens.
        generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True)
        if "\n" in generated_text.lstrip():
            break

        input_ids = torch.cat([input_ids, new_token.to(input_ids.dtype)], dim=-1)

    return generated_text.lstrip().split("\n", 1)[0]

# Generate a continuation for the example prompt with `model` and print it.
generated_text = generate_until_newline(model, tokenizer, input_ids)
print(generated_text)
