In [1]:
import os
import pandas as pd
import numpy as np
import torch
import os
from dotenv import load_dotenv
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from datetime import datetime
import time
import json

# Load environment variables from .env file
load_dotenv()

# Get token from environment
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable")
os.environ["HF_TOKEN"] = HF_TOKEN

# Dataset class for few-shot approach
class MorphologyFewShotDataset(Dataset):
    def __init__(self, data, tokenizer=None, max_length=512, num_examples=3):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_examples = num_examples
        self.examples = []
        
        for _, row in data.iterrows():
            try:
                # Get few-shot examples
                few_shot_examples = self.get_few_shot_examples(row)
                
                # Create prompt with few-shot examples
                prompt = f"Question about morphology. Here are some examples:\n\n"
                prompt += few_shot_examples
                prompt += f"\nNow answer this question:\n"
                prompt += f"Task: {row.get('Task', 'Identify')}\n"
                prompt += f"Word: {row.get('Word', '')}\n"
                prompt += f"Question: {row.get('Instruction', '')}\n"
                
                # Add choices if available
                if pd.notna(row.get('Choice_1', pd.NA)):
                    prompt += "Choices:\n"
                    choice_num = 1
                    while True:
                        choice_key = f'Choice_{choice_num}'
                        if choice_key not in row or pd.isna(row[choice_key]):
                            break
                        prompt += f"{choice_num}. {row[choice_key]}\n"
                        choice_num += 1
                
                # Create expected output
                if pd.notna(row.get('Correct_Answer', pd.NA)):
                    correct_answer = str(row['Correct_Answer'])
                    answer = f"Answer: Choice {correct_answer}"
                else:
                    answer = "Answer: Unable to determine"
                
                # Store example
                self.examples.append({
                    'prompt': prompt,
                    'answer': answer
                })
            except Exception as e:
                print(f"Error processing row: {e}")
        
        print(f"Created {len(self.examples)} few-shot training examples")
    
    def get_few_shot_examples(self, current_row):
        """Get few-shot examples similar to the current question"""
        # Try to find similar examples (same category and task)
        similar_rows = self.data[
            (self.data['Category'] == current_row.get('Category', '')) & 
            (self.data['Task'] == current_row.get('Task', '')) &
            (self.data.index != current_row.name)  # Exclude current row
        ]
        
        # If not enough, try just the same category
        if len(similar_rows) < self.num_examples:
            similar_rows = self.data[
                (self.data['Category'] == current_row.get('Category', '')) &
                (self.data.index != current_row.name)
            ]
        
        # If still not enough, use random examples
        if len(similar_rows) < self.num_examples:
            similar_rows = self.data[self.data.index != current_row.name]
        
        # Sample examples
        if len(similar_rows) <= self.num_examples:
            example_rows = similar_rows
        else:
            example_rows = similar_rows.sample(n=self.num_examples)
        
        # Format examples
        examples_text = ""
        for i, (_, example) in enumerate(example_rows.iterrows(), 1):
            examples_text += f"Example {i}:\n"
            examples_text += f"Task: {example.get('Task', 'Identify')}\n"
            examples_text += f"Word: {example.get('Word', '')}\n"
            examples_text += f"Question: {example.get('Instruction', '')}\n"
            
            if pd.notna(example.get('Choice_1', pd.NA)):
                examples_text += "Choices:\n"
                choice_num = 1
                while True:
                    choice_key = f'Choice_{choice_num}'
                    if choice_key not in example or pd.isna(example[choice_key]):
                        break
                    examples_text += f"{choice_num}. {example[choice_key]}\n"
                    choice_num += 1
            
            if pd.notna(example.get('Correct_Answer', pd.NA)):
                correct_answer = str(example['Correct_Answer'])
                examples_text += f"Answer: Choice {correct_answer}\n\n"
            else:
                examples_text += f"Answer: {example.get('Correct_Answer', 'Unknown')}\n\n"
        
        return examples_text
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return self.examples[idx]

# Simple tokenization function
def tokenize_data(examples, tokenizer, max_length=512):
    """Tokenize a batch of examples"""
    prompts = [ex['prompt'] for ex in examples]
    answers = [ex['answer'] for ex in examples]
    
    # Tokenize inputs
    inputs = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    
    # Tokenize targets
    with tokenizer.as_target_tokenizer():
        targets = tokenizer(
            answers,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
    
    # Create input_ids and labels
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    labels = targets.input_ids
    
    # Replace padding token id with -100 so it's ignored in loss
    labels[labels == tokenizer.pad_token_id] = -100
    
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

# Basic training function
def train_model(train_dataloader, val_dataloader, model, tokenizer, 
               num_epochs=3, learning_rate=5e-5, output_dir="gemma_few_shot"):
    """Train the model with a basic training loop"""
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)
    
    # Setup optimizer
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    
    # Move model to GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    # Training loop
    global_step = 0
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        model.train()
        epoch_loss = 0
        
        # Training
        for batch_idx, batch in enumerate(train_dataloader):
            # Prepare batch
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            loss = outputs.loss
            epoch_loss += loss.item()
            
            # Backward pass
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            global_step += 1
            
            # Log progress
            if batch_idx % 10 == 0:
                print(f"  Batch {batch_idx}: loss = {loss.item():.4f}")
        
        # Calculate average loss for the epoch
        avg_train_loss = epoch_loss / len(train_dataloader)
        train_losses.append(avg_train_loss)
        print(f"  Average training loss: {avg_train_loss:.4f}")
        
        # Validation
        model.eval()
        val_loss = 0
        
        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                
                val_loss += outputs.loss.item()
        
        # Calculate average validation loss
        avg_val_loss = val_loss / len(val_dataloader)
        val_losses.append(avg_val_loss)
        print(f"  Validation loss: {avg_val_loss:.4f}")
        
        # Save model if it's the best so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print(f"  Saving best model with validation loss: {best_val_loss:.4f}")
            
            # Save model
            model_path = os.path.join(output_dir, "best_model")
            os.makedirs(model_path, exist_ok=True)
            
            # Save model state dict
            torch.save(model.state_dict(), os.path.join(model_path, "pytorch_model.bin"))
            
            # Save config
            model_config = model.config.to_dict()
            with open(os.path.join(model_path, "config.json"), 'w') as f:
                json.dump(model_config, f)
            
            # Save tokenizer
            tokenizer.save_pretrained(model_path)
    
    # Save final model
    final_model_path = os.path.join(output_dir, "final_model")
    os.makedirs(final_model_path, exist_ok=True)
    
    # Save model state dict
    torch.save(model.state_dict(), os.path.join(final_model_path, "pytorch_model.bin"))
    
    # Save config
    model_config = model.config.to_dict()
    with open(os.path.join(final_model_path, "config.json"), 'w') as f:
        json.dump(model_config, f)
    
    # Save tokenizer
    tokenizer.save_pretrained(final_model_path)
    
    # Save training stats
    training_stats = {
        "train_losses": train_losses,
        "val_losses": val_losses,
        "best_val_loss": best_val_loss
    }
    
    with open(os.path.join(output_dir, "training_stats.json"), 'w') as f:
        json.dump(training_stats, f)
    
    print(f"Training completed. Final model saved to {final_model_path}")
    return model, tokenizer

# Main function
def main():
    try:
        # Load libraries dynamically to avoid import errors
        print("Loading required libraries...")
        import_start = time.time()
        
        # First try to import a minimal PyTorch
        import torch
        print(f"PyTorch version: {torch.__version__}")
        
        # Then try to import tokenizer and model classes directly
        from transformers import AutoTokenizer, AutoModelForCausalLM
        
        import_end = time.time()
        print(f"Libraries loaded in {import_end - import_start:.2f} seconds")
        
        # Set paths and parameters
        data_path = 'Data/MC_data_MMA.csv'
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_dir = f"gemma_few_shot_{timestamp}"
        os.makedirs(output_dir, exist_ok=True)
        
        # Display GPU information
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"GPU device: {torch.cuda.get_device_name(0)}")
            print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        
        # Load data
        print("Loading data...")
        df = pd.read_csv(data_path)
        print(f"Loaded {len(df)} examples")
        
        # Split data
        train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
        print(f"Training set: {len(train_df)} examples")
        print(f"Validation set: {len(val_df)} examples")
        
        # Load tokenizer and model
        print("Loading tokenizer and model...")
        model_name = "google/gemma-2b-it"
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
        model = AutoModelForCausalLM.from_pretrained(model_name, token=HF_TOKEN)
        
        # Create datasets
        print("Creating few-shot datasets...")
        train_dataset = MorphologyFewShotDataset(train_df, tokenizer, num_examples=3)
        val_dataset = MorphologyFewShotDataset(val_df, tokenizer, num_examples=3)
        
        # Create dataloaders
        print("Creating dataloaders...")
        
        # Collate function
        def collate_fn(batch):
            return tokenize_data(batch, tokenizer)
        
        # Create dataloaders
        train_dataloader = DataLoader(
            train_dataset, 
            batch_size=1,  # Small batch size due to model size
            shuffle=True, 
            collate_fn=collate_fn
        )
        
        val_dataloader = DataLoader(
            val_dataset, 
            batch_size=1, 
            shuffle=False, 
            collate_fn=collate_fn
        )
        
        # Train model
        print("Starting training...")
        model, tokenizer = train_model(
            train_dataloader,
            val_dataloader,
            model,
            tokenizer,
            num_epochs=3,
            learning_rate=5e-5,
            output_dir=output_dir
        )
        
        print("Training completed successfully!")
        
    except Exception as e:
        print(f"An error occurred: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()

Loading required libraries...
PyTorch version: 2.2.2+cu121
Libraries loaded in 1.76 seconds
CUDA available: True
GPU device: NVIDIA A100-SXM4-80GB
GPU memory: 84.97 GB
Loading data...
Loaded 268 examples
Training set: 214 examples
Validation set: 54 examples
Loading tokenizer and model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Creating few-shot datasets...
Created 214 few-shot training examples
Created 54 few-shot training examples
Creating dataloaders...
Starting training...

Epoch 1/3




  Batch 0: loss = 21.2747
  Batch 10: loss = 10.3585
  Batch 20: loss = 2.5222
  Batch 30: loss = 1.3865
  Batch 40: loss = 2.2768
  Batch 50: loss = 3.0543
  Batch 60: loss = 2.1271
  Batch 70: loss = 1.3417
  Batch 80: loss = 0.7835
  Batch 90: loss = 1.2772
  Batch 100: loss = 2.9002
  Batch 110: loss = 2.0829
  Batch 120: loss = 3.5104
  Batch 130: loss = 1.7551
  Batch 140: loss = 1.6231
  Batch 150: loss = 2.5616
  Batch 160: loss = 3.5322
  Batch 170: loss = 1.0038
  Batch 180: loss = 0.7870
  Batch 190: loss = 0.8998
  Batch 200: loss = 2.6117
  Batch 210: loss = 2.1571
  Average training loss: 2.5146
  Validation loss: 2.0888
  Saving best model with validation loss: 2.0888

Epoch 2/3
  Batch 0: loss = 2.0619
  Batch 10: loss = 0.8950
  Batch 20: loss = 0.5877
  Batch 30: loss = 1.6916
  Batch 40: loss = 3.6034
  Batch 50: loss = 2.0802
  Batch 60: loss = 1.7429
  Batch 70: loss = 0.8264
  Batch 80: loss = 3.0027
  Batch 90: loss = 1.0799
  Batch 100: loss = 0.4487
  Batch 110