# Enhanced Transformer Training - TEST VERSION

**Testing Model Saving Functionality**

This is a test version with:
- Small dataset for quick training
- Fewer epochs
- Robust saving mechanisms

## 🚀 Setup and Installation

In [None]:
# Cell 1: Install All Requirements for Enhanced Transformer Training
import subprocess
import sys
import os
from IPython.display import clear_output

print("🚀 Installing Enhanced Transformer Requirements...")
print("=" * 60)

# Update pip first
!pip install --upgrade pip

# Install stable PyTorch version
print("🔥 Installing PyTorch with GPU support...")
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install core ML libraries
print("📊 Installing core ML libraries...")
!pip install numpy pandas scikit-learn matplotlib seaborn plotly

# Install technical analysis libraries
print("📈 Installing technical analysis libraries...")
!pip install ta talib-binary

# Install utilities
print("🔧 Installing utilities...")
!pip install tqdm psutil requests ipywidgets

clear_output(wait=True)
print("✅ All requirements installed successfully!")

# Check GPU availability
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"🚀 Default device: {device}")

if torch.cuda.is_available():
  print(f"💻 GPU: {torch.cuda.get_device_name(0)}")
  print(f"🧠 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
elif torch.backends.mps.is_available():
  print("🍎 Apple Silicon GPU available")
else:
  print("⚠️ No GPU detected, using CPU")

print("\n🎯 Ready for enhanced transformer training!")

In [None]:
# Check GPU availability and handle PyTorch import issues
import torch
import warnings
import sys

# Workaround for PyTorch import issues: fall back to a stub when the internal
# helper cannot be imported from this PyTorch build.
try:
    from torch._utils_internal import justknobs_check
except ImportError:
    def justknobs_check(name, default=False):
        """Stand-in that always answers with the supplied default."""
        return default

# Fix for PyTorch 2.2+ pytree API changes
try:
    # Check if register_pytree_node exists
    from torch.utils._pytree import register_pytree_node
except ImportError:
    # Apply monkey patch for older pytree API
    import torch.utils._pytree as pytree
    if not hasattr(pytree, 'register_pytree_node'):
        def register_pytree_node(*args, **kwargs):
            """No-op placeholder so callers of the newer API do not fail."""
            pass
        pytree.register_pytree_node = register_pytree_node

# Disable torch.compile to avoid more issues. Builds without the dynamo
# subpackage have nothing to disable, so a failed import is not an error here.
try:
    import torch._dynamo
    torch._dynamo.config.disable = True
except (ImportError, AttributeError) as dynamo_error:
    print(f"ℹ️ torch._dynamo not available, nothing to disable: {dynamo_error}")

# NOTE: this hides every warning category for the rest of the session.
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"🚀 Using device: {device}")
print(f"🔧 PyTorch version: {torch.__version__}")

# Report the accelerator that was actually selected above
if torch.cuda.is_available():
    print(f"💻 GPU: {torch.cuda.get_device_name(0)}")
    print(f"🧠 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
elif torch.backends.mps.is_available():
    print("🍎 Apple Silicon GPU available")
else:
    print("⚠️ No GPU detected, using CPU")

In [None]:
# Install required packages
!pip install pandas numpy scikit-learn matplotlib ta talib-binary
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install stable-baselines3

print("✅ Packages installed successfully")

## 📊 Data Loading and Preprocessing (Small Dataset for Testing)

In [None]:
import pandas as pd
import numpy as np
from datetime import datetime
import os

# Load cryptocurrency data
def load_crypto_data(csv_path='sample_crypto_data.csv'):
    """Load OHLCV data from ``csv_path``, generating a sample file if absent.

    When the file does not exist, 5000 rows of synthetic 5-minute BTCUSDT bars
    are generated (seeded, so reproducible), written to ``csv_path`` and
    returned. Otherwise the CSV is read as-is.

    The returned frame always carries a datetime index:
    * a ``date`` column is parsed and promoted to the index;
    * else a ``timestamp`` column is parsed into ``date`` and promoted;
    * else a synthetic 5-minute index starting 2024-01-01 is attached.

    Args:
        csv_path: Location of the CSV file to read (or to create).

    Returns:
        pandas.DataFrame indexed by datetime.
    """
    print(f"📊 Loading data from {csv_path}...")

    # Create sample data for testing
    if not os.path.exists(csv_path):
        print("🔧 Creating sample data for testing...")
        n_rows = 5000
        # The index is named 'date' so the CSV written below gets a 'date'
        # column and reloads to the same layout (no stray 'Unnamed: 0').
        dates = pd.date_range(start='2024-01-01', periods=n_rows, freq='5min', name='date')
        np.random.seed(42)

        # Generate realistic price data: a geometric random walk from base_price.
        # The first return is skipped so the series starts exactly at base_price.
        base_price = 50000
        returns = np.random.normal(0, 0.001, n_rows)
        prices = np.cumprod(np.concatenate(([float(base_price)], 1 + returns[1:])))

        # Random draws are taken in the order open, high, low, volume.
        open_prices = prices * (1 + np.random.normal(0, 0.0005, n_rows))
        high_noise = np.abs(np.random.normal(0, 0.001, n_rows))
        low_noise = np.abs(np.random.normal(0, 0.001, n_rows))
        volume = np.random.uniform(100, 1000, n_rows)

        sample_data = pd.DataFrame({
            'tic': ['BTCUSDT'] * n_rows,
            'open': open_prices,
            # Anchor on max/min of open and close so every bar satisfies
            # low <= open, close <= high.
            'high': np.maximum(open_prices, prices) * (1 + high_noise),
            'low': np.minimum(open_prices, prices) * (1 - low_noise),
            'close': prices,
            'volume': volume,
        }, index=dates)

        sample_data.to_csv(csv_path, index_label='date')
        df = sample_data
    else:
        df = pd.read_csv(csv_path)

    print(f"✅ Raw data shape: {df.shape}")

    # Handle datetime index
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])
        df.set_index('date', inplace=True)
    elif 'timestamp' in df.columns:
        df['date'] = pd.to_datetime(df['timestamp'])
        df.set_index('date', inplace=True)
    elif len(df) > 0 and not isinstance(df.index, pd.DatetimeIndex):
        # No usable time column: attach a synthetic 5-minute index
        df.index = pd.date_range(start='2024-01-01', periods=len(df), freq='5min')

    print(f"📅 Date range: {df.index.min()} to {df.index.max()}")
    print(f"💰 Symbols: {df['tic'].unique() if 'tic' in df.columns else 'Unknown'}")

    return df

# Load data
df = load_crypto_data()

if df is not None:
    display(df.head())
    print(f"\n📋 Data info:")
    # DataFrame.info() prints its report itself and returns None, so it is
    # called directly; wrapping it in display() would render a stray "None".
    df.info()

## 🔧 Enhanced Feature Engineering

In [None]:
# Import enhanced features module
import sys
import os
sys.path.append(os.getcwd())

from enhanced_features import calculate_enhanced_features, select_important_features

# Calculate enhanced features
def process_features(df):
    """Build the model's feature table from the raw price frame.

    Runs ``calculate_enhanced_features`` on ``df``, narrows the result with
    ``select_important_features`` (25 features for this test run) and fills
    missing values: forward fill, then backward fill, then 0 for anything
    still missing.

    Args:
        df: Raw data frame as returned by ``load_crypto_data``.

    Returns:
        The selected feature frame with no missing values.
    """
    print("🔧 Calculating enhanced features...")

    # Calculate enhanced features
    enhanced_df = calculate_enhanced_features(df)
    print(f"✅ Enhanced features shape: {enhanced_df.shape}")

    # Select important features
    selected_features = select_important_features(enhanced_df, n_features=25)  # Reduced for testing
    print(f"🎯 Selected features shape: {selected_features.shape}")

    # Handle missing values. fillna(method=...) is deprecated in recent pandas;
    # the dedicated ffill()/bfill() methods are the supported equivalent.
    selected_features = selected_features.ffill().bfill().fillna(0)

    print(f"✅ Final processed features shape: {selected_features.shape}")
    print(f"📋 Feature columns: {list(selected_features.columns[:10])}...")

    return selected_features

# Process features
# Runs the feature pipeline on the loaded frame and previews the first rows.
# Later cells check for the global name `features_df` to decide whether real
# data is available, so it is bound here.
if df is not None:
    features_df = process_features(df)
    display(features_df.head())

## 🧠 Enhanced Transformer Model

In [None]:
# Import enhanced transformer
from transformer_enhanced_v2 import EnhancedCryptoTransformer, create_enhanced_transformer_config
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

# Start from the stock configuration, then shrink it for a quick test run
config = create_enhanced_transformer_config()
config['model_params']['max_seq_len'] = 50  # Reduced for testing
config['training_params']['n_epochs'] = 5  # Reduced for testing
config['training_params']['batch_size'] = 16  # Reduced for testing

print("📋 Model configuration:")
for key, value in config.items():
    print(f"   {key}: {value}")

# Build the model: fall back to a 25-feature placeholder when no features exist
if 'features_df' not in locals():
    print("⚠️ Features not available, creating test model")
    model = EnhancedCryptoTransformer(input_dim=25, **config['model_params']).to(device)
    print(f"🧠 Test model created with 25 input dimensions")
else:
    # One input channel per feature column
    input_dim = features_df.shape[1]
    model = EnhancedCryptoTransformer(input_dim=input_dim, **config['model_params']).to(device)

    print(f"\n🧠 Model created successfully!")
    print(f"📊 Input dimension: {input_dim}")
    print(f"🔧 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    # Size estimate uses 4 bytes per parameter
    print(f"💾 Model size: {sum(p.numel() for p in model.parameters()) * 4 / 1024 / 1024:.1f} MB")

In [None]:
# Test model forward pass
def test_model(model, input_dim=25):
    """Run one no-grad forward pass on random inputs and print output shapes.

    Reads the sequence length from the global ``config`` and places tensors on
    the global ``device``.

    Args:
        model: Callable taking ``(inputs, scale_inputs)`` and returning a
            mapping of outputs.
        input_dim: Size of the last (feature) dimension of the random inputs.

    Returns:
        Whatever mapping the model returned.
    """
    print("🧪 Testing model forward pass...")

    n_samples = 4
    window = config['model_params']['max_seq_len']

    def _random_batch(length):
        # Random tensor of shape (n_samples, length, input_dim) on the target device
        return torch.randn(n_samples, length, input_dim).to(device)

    # Main input first, then one tensor per scale (5 / 15 / 30), with the
    # coarser scales using a third and a sixth of the window length.
    test_input = _random_batch(window)
    scale_inputs = {
        5: _random_batch(window),
        15: _random_batch(window // 3),
        30: _random_batch(window // 6),
    }

    model.eval()
    with torch.no_grad():
        outputs = model(test_input, scale_inputs)

    print("✅ Model forward pass successful!")
    print("📊 Output shapes:")
    for name, item in outputs.items():
        if not isinstance(item, torch.Tensor):
            continue  # only tensors have a shape worth reporting
        print(f"   {name}: {item.shape}")

    return outputs

# Test model
# `input_dim` is only bound when the model was built from `features_df`;
# otherwise use the 25-feature width of the placeholder model.
test_outputs = test_model(model, input_dim if 'features_df' in locals() else 25)

## 🏋️‍♂️ Training Setup

In [None]:
# Create dataset and dataloader
class CryptoDataset(torch.utils.data.Dataset):
    """Sliding-window dataset pairing feature windows with future returns.

    Each sample is a window of ``sequence_length`` consecutive feature rows
    together with the relative price change between the last row of the window
    and the row ``prediction_horizon`` steps after it.

    Prices come from the ``close`` column when present, otherwise from the
    first feature column.
    """

    def __init__(self, features_df, sequence_length=50, prediction_horizon=5):
        """Materialise all windows and targets from ``features_df``.

        Args:
            features_df: Feature frame, one row per time step.
            sequence_length: Number of rows per input window (>= 1).
            prediction_horizon: Steps ahead of the window's last row at which
                the target price is read (>= 0).

        Raises:
            ValueError: If ``sequence_length`` or ``prediction_horizon`` is
                out of range.
        """
        if sequence_length < 1:
            raise ValueError("sequence_length must be at least 1")
        if prediction_horizon < 0:
            raise ValueError("prediction_horizon must not be negative")

        self.features = features_df.values
        self.sequence_length = sequence_length
        self.prediction_horizon = prediction_horizon
        self.close_prices = features_df['close'].values if 'close' in features_df.columns else self.features[:, 0]

        self.sequences, self.targets = self._prepare_sequences()

    def _prepare_sequences(self):
        """Return ``(sequences, targets)`` arrays for every valid window.

        A window starting at row ``i`` needs rows up to
        ``i + sequence_length + prediction_horizon - 1``, so there are
        ``len(features) - sequence_length - prediction_horizon + 1`` valid
        start positions (zero when the data is too short).
        """
        n_windows = max(
            0,
            len(self.features) - self.sequence_length - self.prediction_horizon + 1,
        )

        # Input sequences: consecutive row slices of the feature matrix
        sequences = [
            self.features[start:start + self.sequence_length]
            for start in range(n_windows)
        ]

        # Targets (future returns), computed for all windows at once.
        # `last` is the offset of a window's final row relative to its start.
        last = self.sequence_length - 1
        current_price = self.close_prices[last:last + n_windows]
        future_price = self.close_prices[
            last + self.prediction_horizon:last + self.prediction_horizon + n_windows
        ]
        targets = (future_price - current_price) / current_price

        return np.array(sequences), np.array(targets)

    def __len__(self):
        """Number of windows available."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(sequence, target)`` as float tensors; target has one element."""
        sequence = torch.FloatTensor(self.sequences[idx])
        target = torch.FloatTensor([self.targets[idx]])
        return sequence, target

# Create datasets
if 'features_df' in locals():
    print("🔧 Creating datasets...")

    # Create dataset
    full_dataset = CryptoDataset(features_df, sequence_length=config['model_params']['max_seq_len'])

    # Split data 80/20 in time order.
    # Neighbouring windows share all but one row, so a random split would put
    # near-duplicates of every validation window into the training set and
    # make the validation loss meaningless. The earliest 80% of windows train
    # the model; the most recent 20% validate it.
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    train_dataset = torch.utils.data.Subset(full_dataset, range(train_size))
    val_dataset = torch.utils.data.Subset(
        full_dataset, range(train_size, len(full_dataset))
    )

    print(f"📊 Training samples: {len(train_dataset)}")
    print(f"📊 Validation samples: {len(val_dataset)}")

    # Create dataloaders
    batch_size = config['training_params']['batch_size']
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,  # shuffling inside the training split is still fine
        num_workers=0  # Set to 0 to avoid multiprocessing issues
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0
    )

    print(f"🔧 Batch size: {batch_size}")
    print(f"🔧 Training batches: {len(train_loader)}")
    print(f"🔧 Validation batches: {len(val_loader)}")
else:
    print("⚠️ Features not available, skipping dataset creation")

## 🚀 Training Loop (Test Version)

In [None]:
# Training setup
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import time
import warnings

# Hide UserWarning-category messages for the rest of the session
warnings.filterwarnings('ignore', category=UserWarning)

# Build optimizer, scheduler, losses and history, but only if data loaders exist
try:
    if 'train_loader' not in locals():
        print("⚠️ Training setup skipped - no datasets available")
    else:
        # AdamW driven by the configured learning rate and weight decay
        optimizer = optim.AdamW(model.parameters(), lr=config['training_params']['learning_rate'], weight_decay=config['training_params']['weight_decay'])

        # Cosine schedule over all epochs, bottoming out at 10% of the base rate
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['training_params']['n_epochs'], eta_min=config['training_params']['learning_rate'] * 0.1)

        # Both heads are scored with mean squared error
        action_loss_fn = nn.MSELoss()
        confidence_loss_fn = nn.MSELoss()

        # One empty series per tracked metric
        training_history = {
            metric: []
            for metric in ('train_loss', 'val_loss', 'learning_rate', 'epoch_time', 'gpu_memory')
        }

        print("🚀 Training setup completed!")
except Exception as exc:
    # Best-effort cell: report the problem and let the notebook continue
    print(f"⚠️ Error setting up training: {str(exc)}")
    print("This might be due to PyTorch compatibility issues. Please restart the kernel and try again.")

In [None]:
# Training function
def train_epoch(model, train_loader, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    The loss is the action MSE plus 0.2 times the confidence MSE, where the
    confidence head is pulled towards a constant 0.8. Gradients are clipped
    to a global norm of 1.0. Uses the module-level ``action_loss_fn`` and
    ``confidence_loss_fn``.

    Args:
        model: Network returning a mapping with ``'action'`` and
            ``'confidence'`` tensors.
        train_loader: Iterable of ``(sequences, targets)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        dict with the per-batch means ``total_loss``, ``action_loss`` and
        ``confidence_loss``.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    action_loss = 0.0
    confidence_loss = 0.0
    num_batches = 0
    n_total = len(train_loader)

    for batch_idx, (sequences, targets) in enumerate(train_loader):
        sequences = sequences.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(sequences)
        action_pred = outputs['action']
        confidence_pred = outputs['confidence']

        # MSELoss broadcasts operands of different shapes, which silently
        # yields a wrong loss. When the element counts agree, align the target
        # to the prediction's shape (presumably (batch, 1) — TODO confirm
        # against the model's action head).
        if targets.shape != action_pred.shape and targets.numel() == action_pred.numel():
            targets = targets.reshape(action_pred.shape)

        # Calculate losses
        action_loss_batch = action_loss_fn(action_pred, targets)
        confidence_loss_batch = confidence_loss_fn(
            confidence_pred, torch.full_like(confidence_pred, 0.8)
        )

        # Total loss
        total_loss_batch = action_loss_batch + 0.2 * confidence_loss_batch

        # Backward pass
        total_loss_batch.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Accumulate losses
        total_loss += total_loss_batch.item()
        action_loss += action_loss_batch.item()
        confidence_loss += confidence_loss_batch.item()
        num_batches += 1

        if batch_idx % 5 == 0:  # Reduced frequency for testing
            print(f"  Batch {batch_idx}/{n_total}: Loss = {total_loss_batch.item():.4f}")

    if num_batches == 0:
        # Averaging over zero batches would otherwise raise ZeroDivisionError
        raise ValueError("train_loader yielded no batches; cannot compute epoch losses")

    return {
        'total_loss': total_loss / num_batches,
        'action_loss': action_loss / num_batches,
        'confidence_loss': confidence_loss / num_batches
    }

def validate_epoch(model, val_loader, device):
    """Return the mean action loss of ``model`` over ``val_loader``.

    Runs in eval mode without gradients and scores only the ``'action'``
    output with the module-level ``action_loss_fn``.

    Args:
        model: Network returning a mapping with an ``'action'`` tensor.
        val_loader: Iterable of ``(sequences, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        Mean per-batch action loss as a float.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for sequences, targets in val_loader:
            sequences = sequences.to(device)
            targets = targets.to(device)

            action_pred = model(sequences)['action']

            # MSELoss broadcasts operands of different shapes, which silently
            # yields a wrong loss; align the target when element counts agree.
            if targets.shape != action_pred.shape and targets.numel() == action_pred.numel():
                targets = targets.reshape(action_pred.shape)

            total_loss += action_loss_fn(action_pred, targets).item()
            num_batches += 1

    if num_batches == 0:
        # Averaging over zero batches would otherwise raise ZeroDivisionError
        raise ValueError("val_loader yielded no batches; cannot compute validation loss")

    return total_loss / num_batches

print("🔧 Training functions defined!")

In [None]:
# Start training
def _safe_save(obj, filename):
    """Write ``obj`` with ``torch.save`` in the legacy (non-zip) format.

    Tries pickle protocol 4 first and protocol 2 as a fallback.

    Returns:
        True if either attempt succeeded, False otherwise (the last error is
        printed, never raised).
    """
    last_error = None
    for protocol in (4, 2):
        try:
            torch.save(obj, filename, pickle_protocol=protocol, _use_new_zipfile_serialization=False)
        except Exception as err:  # deliberate best-effort: try the next protocol
            last_error = err
            continue
        note = "" if protocol == 4 else " (protocol 2)"
        print(f"✅ Successfully saved {filename}{note}")
        return True

    print(f"   ⚠️ Failed to save {filename}: {last_error}")
    return False


def _test_save_methods(model, checkpoint, prefix="test"):
    """Try four ways of persisting the model and report which files exist.

    Files written (all prefixed with ``prefix``): ``_full.pth`` (whole
    checkpoint), ``_state.pth`` (state dict), ``_cpu.pth`` (state dict moved
    to CPU) and ``_weights.pkl`` (pickled numpy arrays). Failures are printed,
    never raised.
    """
    import pickle

    print(f"\n🧪 Testing save methods for {prefix}...")

    # Each method writes its own file; the summary below checks exactly that file.
    outputs = {
        'full': f'{prefix}_full.pth',
        'state': f'{prefix}_state.pth',
        'cpu': f'{prefix}_cpu.pth',
        'weights': f'{prefix}_weights.pkl',
    }

    # Method 1: Full checkpoint
    if _safe_save(checkpoint, outputs['full']):
        print("   ✅ Method 1: Full checkpoint - SUCCESS")

    # Method 2: State dict only
    try:
        torch.save(model.state_dict(), outputs['state'],
                   _use_new_zipfile_serialization=False)
        print("   ✅ Method 2: State dict only - SUCCESS")
    except Exception as e:
        print(f"   ❌ Method 2 failed: {e}")

    # Method 3: CPU state dict
    try:
        cpu_model = {k: v.cpu() for k, v in model.state_dict().items()}
        torch.save(cpu_model, outputs['cpu'],
                   _use_new_zipfile_serialization=False)
        print("   ✅ Method 3: CPU state dict - SUCCESS")
    except Exception as e:
        print(f"   ❌ Method 3 failed: {e}")

    # Method 4: Numpy arrays
    try:
        numpy_weights = {
            name: param.cpu().numpy()
            for name, param in model.state_dict().items()
        }
        with open(outputs['weights'], 'wb') as f:
            pickle.dump(numpy_weights, f)
        print("   ✅ Method 4: Numpy arrays - SUCCESS")
    except Exception as e:
        print(f"   ❌ Method 4 failed: {e}")

    print(f"\n📋 Save test summary for {prefix}:")
    for method, path in outputs.items():
        if os.path.exists(path):
            print(f"   ✅ {method}: Saved")
        else:
            print(f"   ❌ {method}: Failed")


def start_training(model, train_loader, val_loader, optimizer, scheduler, config, device):
    """Train, validate and checkpoint the model, exercising every save method.

    For each of ``config['training_params']['n_epochs']`` epochs this runs
    ``train_epoch`` and ``validate_epoch``, steps the scheduler, records
    metrics and, whenever the validation loss improves, writes
    ``enhanced_transformer_best.pth`` plus the four save-method test files.
    The save methods are also tested once before training and once for the
    final model (``enhanced_transformer_final.pth``).

    Args:
        model: Network to train.
        train_loader: Training batches.
        val_loader: Validation batches.
        optimizer: Optimizer over the model's parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        config: Configuration dict; reads ``training_params.n_epochs``.
        device: Device passed through to the epoch functions.

    Returns:
        dict of per-epoch lists: ``train_loss``, ``val_loss``,
        ``learning_rate``, ``epoch_time`` and ``gpu_memory`` (the last is only
        filled when CUDA is available).
    """
    n_epochs = config['training_params']['n_epochs']

    print("🚀 Starting enhanced transformer training (TEST VERSION)...")
    print(f"📊 Training samples: {len(train_loader.dataset)}")
    print(f"📊 Validation samples: {len(val_loader.dataset)}")
    print(f"🧠 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"🔧 Epochs: {n_epochs}")

    best_val_loss = float('inf')
    training_history = {
        'train_loss': [],
        'val_loss': [],
        'learning_rate': [],
        'epoch_time': [],
        'gpu_memory': []
    }

    # Test save methods before training so serialization problems show up
    # before any time is spent on epochs.
    test_checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config,
        'training_history': training_history,
        'epoch': 0
    }
    _test_save_methods(model, test_checkpoint, "before_training")

    for epoch in range(n_epochs):
        start_time = time.time()

        # Training
        train_losses = train_epoch(model, train_loader, optimizer, device)

        # Validation
        val_loss = validate_epoch(model, val_loader, device)

        # Learning rate scheduling
        scheduler.step()

        # Record metrics (the learning rate is read after the scheduler step)
        epoch_time = time.time() - start_time
        training_history['train_loss'].append(train_losses['total_loss'])
        training_history['val_loss'].append(val_loss)
        training_history['learning_rate'].append(optimizer.param_groups[0]['lr'])
        training_history['epoch_time'].append(epoch_time)

        # GPU memory usage
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.memory_allocated() / 1024**3
            training_history['gpu_memory'].append(gpu_memory)

        # Print progress
        print(f"\n📊 Epoch {epoch+1}/{n_epochs}")
        print(f"   Train Loss: {train_losses['total_loss']:.4f} (Action: {train_losses['action_loss']:.4f})")
        print(f"   Val Loss: {val_loss:.4f}")
        print(f"   LR: {optimizer.param_groups[0]['lr']:.6f}")
        print(f"   Time: {epoch_time:.1f}s")
        if torch.cuda.is_available():
            print(f"   GPU Memory: {gpu_memory:.1f} GB")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
                'training_history': training_history,
                'epoch': epoch
            }

            print("   💾 Saving best model...")
            _safe_save(checkpoint, 'enhanced_transformer_best.pth')

            # Test all save methods
            _test_save_methods(model, checkpoint, f"epoch_{epoch+1}_best")

    # Save final model
    print("\n💾 Saving final model...")
    final_checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': config,
        'training_history': training_history
    }

    _safe_save(final_checkpoint, 'enhanced_transformer_final.pth')
    _test_save_methods(model, final_checkpoint, "final")

    print("\n✅ Training completed!")
    print(f"🏆 Best validation loss: {best_val_loss:.4f}")

    # List all saved files
    print("\n📋 Saved files:")
    for f in os.listdir('.'):
        if 'enhanced_transformer' in f or f.endswith(('.pth', '.pkl')):
            size_bytes = os.path.getsize(f)
            if size_bytes > 0:
                print(f"   ✅ {f}: {size_bytes / 1024 / 1024:.1f} MB")
            else:
                print(f"   ❌ {f}: Empty file")

    return training_history

# Start training
# Every object passed to start_training must exist. The setup cell can stop
# part-way (e.g. after creating the optimizer but before the scheduler), so
# each name is checked explicitly rather than relying on a subset.
_required_names = ('model', 'train_loader', 'val_loader', 'optimizer', 'scheduler')
_missing_names = [name for name in _required_names if name not in globals()]

if not _missing_names:
    training_history = start_training(model, train_loader, val_loader, optimizer, scheduler, config, device)
else:
    print("⚠️ Training setup not available. Please run all cells above first.")
    print(f"   Missing: {', '.join(_missing_names)}")

## 🧪 Test Loading Models

In [None]:
# Test loading saved models
def test_model_loading():
    """Try to ``torch.load`` each expected checkpoint file and report on it.

    For every file that exists, prints its size and, when the loaded object is
    a dict, its keys and the parameter count of ``model_state_dict`` if
    present. Missing files and load errors are printed, never raised. Tensors
    are mapped onto the global ``device``.
    """
    print("🧪 Testing model loading...")

    # Checkpoints written by the training run
    candidates = (
        'enhanced_transformer_best.pth',
        'enhanced_transformer_final.pth',
        'final_full.pth',
        'final_cpu.pth',
        'final_state.pth',
    )

    for path in candidates:
        if not os.path.exists(path):
            print(f"\n❌ File not found: {path}")
            continue

        print(f"\n📂 Testing {path}...")
        try:
            checkpoint = torch.load(path, map_location=device)
            print(f"   ✅ Loaded successfully")
            size_mb = os.path.getsize(path) / 1024 / 1024
            print(f"   📊 File size: {size_mb:.1f} MB")

            # Only dict checkpoints have keys worth listing
            if not isinstance(checkpoint, dict):
                continue
            print(f"   📋 Keys: {list(checkpoint.keys())}")
            if 'model_state_dict' in checkpoint:
                n_params = sum(p.numel() for p in checkpoint['model_state_dict'].values())
                print(f"   🔧 Model parameters: {n_params:,}")
        except Exception as e:
            print(f"   ❌ Failed to load: {e}")

# Run loading test
# The training-setup cell already binds an empty `training_history`, so the
# name existing does not mean training ran; require at least one recorded epoch.
if 'training_history' in locals() and training_history.get('train_loss'):
    test_model_loading()
else:
    print("⚠️ No training history available. Run training first.")

## 📊 Test Summary

In [None]:
# Print test summary
def print_test_summary(history=None):
    """Print a report of saved checkpoint files and the training outcome.

    Scans the current directory for files whose name contains
    ``enhanced_transformer`` or ends in ``.pth``/``.pkl``; non-empty ones
    count as saved, empty ones as failed. When a training history with at
    least one epoch is available, the final losses and a recommendation of
    which save method to use are printed as well.

    Args:
        history: Training history dict with ``train_loss`` and ``val_loss``
            lists. When omitted, the module-level ``training_history`` is
            used if it exists.
    """
    # locals() inside a function never contains module-level names, so the
    # notebook's global history has to be looked up through globals().
    if history is None:
        history = globals().get('training_history')

    print("\n" + "="*60)
    print("🧪 ENHANCED TRANSFORMER TRAINING TEST SUMMARY")
    print("="*60)

    # Count saved files
    saved_files = []
    failed_methods = []

    for f in os.listdir('.'):
        if 'enhanced_transformer' in f or f.endswith(('.pth', '.pkl')):
            if os.path.getsize(f) > 0:
                saved_files.append(f)
            else:
                failed_methods.append(f)

    print(f"\n📊 Results:")
    print(f"   ✅ Successfully saved files: {len(saved_files)}")
    print(f"   ❌ Failed/empty files: {len(failed_methods)}")

    print(f"\n💾 Saved files:")
    for f in saved_files:
        size_mb = os.path.getsize(f) / 1024 / 1024
        print(f"   - {f}: {size_mb:.1f} MB")

    if failed_methods:
        print(f"\n❌ Failed methods:")
        for f in failed_methods:
            print(f"   - {f}")

    # Training results
    if history and history.get('train_loss'):
        print(f"\n📈 Training results:")
        print(f"   - Final train loss: {history['train_loss'][-1]:.4f}")
        print(f"   - Final val loss: {history['val_loss'][-1]:.4f}")
        print(f"   - Total epochs: {len(history['train_loss'])}")

        # Best method recommendation: first save method (in order of
        # preference) that produced a non-empty file.
        print(f"\n🎯 Recommendations for cloud training:")

        if any('full.pth' in f for f in saved_files):
            print("   ✅ Full checkpoint saving works - use this method")
        elif any('state.pth' in f for f in saved_files):
            print("   ✅ State dict saving works - use this as backup")
        elif any('cpu.pth' in f for f in saved_files):
            print("   ✅ CPU saving works - use this if GPU saving fails")
        elif any('weights.pkl' in f for f in saved_files):
            print("   ✅ Numpy saving works - use this as last resort")
        else:
            print("   ❌ All methods failed - need to investigate further")

    print("\n" + "="*60)

# Print summary
# Reports on whatever checkpoint files are present in the current directory.
print_test_summary()