In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import sys

# Constants
RANDOM_SEED = 42
BATCH_SIZE = 32
EPOCHS = 40
LEARNING_RATE = 0.0005  # Reduced learning rate
WEIGHT_DECAY = 0.01     # L2 regularization added (passed to the optimizer as weight_decay)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = '/home/geiger/asl_detection/machine_learning/models/asl_now/best_model.pth'  # Path to the existing model; the best checkpoint is also saved here
NOISE_LEVEL = 0.05      # Scale factor applied to the random noise used for data augmentation

# Dataset class with optional data augmentation
class HandSignDataset(Dataset):
    """Keypoint/label dataset with optional noise and rotation augmentation.

    Args:
        keypoints: Array-like of flattened keypoint vectors, one row per
            sample. The rotation augmentation reshapes each row to
            ``(-1, 3)``, so the row length must be a multiple of 3.
        labels: Array-like of integer class labels, one per sample.
        augment: When true, ``__getitem__`` perturbs the returned sample.
        noise_level: Scale of the additive Gaussian noise. ``None`` (default)
            uses the module-level ``NOISE_LEVEL`` at access time.
        max_rotation: Half-width, in radians, of the uniform range the
            rotation angle is drawn from (default 0.1, roughly 5.7 degrees).
    """

    def __init__(self, keypoints, labels, augment=False, noise_level=None,
                 max_rotation=0.1):
        self.keypoints = torch.FloatTensor(keypoints)
        self.labels = torch.LongTensor(labels)
        self.augment = augment
        self.noise_level = noise_level
        self.max_rotation = max_rotation

    def __len__(self):
        return len(self.keypoints)

    def _rotate_xy(self, keypoints):
        """Rotate the (x, y) part of every 3-D point by a small random angle."""
        # Uniform angle in [-max_rotation, +max_rotation)
        angle = torch.rand(1).item() * (2 * self.max_rotation) - self.max_rotation

        # Evaluate sin/cos in the sample's dtype, then build the matrix from
        # plain floats rather than nesting 0-dim tensors in torch.tensor().
        angle_t = torch.tensor(angle, dtype=keypoints.dtype)
        cos_a = torch.cos(angle_t).item()
        sin_a = torch.sin(angle_t).item()
        rotation = torch.tensor(
            [[cos_a, -sin_a, 0.0],
             [sin_a, cos_a, 0.0],
             [0.0, 0.0, 1.0]],
            dtype=keypoints.dtype,
        )

        # One row per landmark; the z column is left untouched by the matrix
        points = keypoints.view(-1, 3)
        return torch.matmul(points, rotation).view(-1)

    def __getitem__(self, idx):
        keypoints = self.keypoints[idx]

        # Augmentation is meant for training data only
        if self.augment:
            # Resolve the default lazily so the module constant stays the
            # single source of truth when no explicit level was given.
            level = NOISE_LEVEL if self.noise_level is None else self.noise_level
            keypoints = keypoints + torch.randn_like(keypoints) * level

            # Apply the small rotation to about half of the samples
            if torch.rand(1).item() > 0.5:
                keypoints = self._rotate_xy(keypoints)

        return keypoints, self.labels[idx]

# Improved model definition with stronger dropout
class HandSignNet(nn.Module):
    """Fully connected classifier for 63-dimensional keypoint vectors.

    Three Linear -> ReLU -> BatchNorm1d -> Dropout(0.5) stages
    (63 -> 128 -> 256 -> 128) feed a single linear layer that emits one
    logit per class.
    """

    def __init__(self, num_classes=24):
        super().__init__()

        # Layer widths of the feature extractor, input size first
        widths = (63, 128, 256, 128)

        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.extend([
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(0.5),  # increased dropout
            ])
        self.features = nn.Sequential(*stages)

        # Classifier head
        self.classifier = nn.Sequential(nn.Linear(widths[-1], num_classes))

    def forward(self, x):
        return self.classifier(self.features(x))

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable. It is assumed to return the mean loss of
            the batch, so batches are weighted by their size when averaged.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-sample mean loss and the
        accuracy in percent. Both are ``0.0`` when the loader yields no
        samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is not over-counted
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

    # Empty loader: report zeros rather than dividing by zero
    if total == 0:
        return 0.0, 0.0

    epoch_loss = loss_sum / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable. It is assumed to return the mean loss of
            the batch, so batches are weighted by their size when averaged.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc, all_preds, all_labels)``: the per-sample
        mean loss, the accuracy in percent, and the flat lists of predicted
        and true labels. Loss and accuracy are ``0.0`` and the lists are
        empty when the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight by batch size so a smaller final batch is not over-counted
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Empty loader: report zeros rather than dividing by zero
    if total == 0:
        return 0.0, 0.0, all_preds, all_labels

    epoch_loss = loss_sum / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc, all_preds, all_labels

def plot_confusion_matrix(y_true, y_pred, classes):
    """Render the confusion matrix as a heatmap and save it to disk.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        classes: Tick labels used on both axes of the heatmap.

    The figure is written to ``confusion_matrix.png`` in the current working
    directory and closed afterwards.
    """
    matrix = confusion_matrix(y_true, y_pred)

    fig, ax = plt.subplots(figsize=(15, 12))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=classes,
        yticklabels=classes,
        ax=ax,
    )
    ax.set_title('Confusion Matrix')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')

    fig.savefig('confusion_matrix.png')
    plt.close(fig)

def plot_training_history(train_losses, val_losses, train_accs, val_accs):
    """Plot per-epoch loss and accuracy side by side and save the figure.

    Args:
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        train_accs: Training accuracy per epoch, in percent.
        val_accs: Validation accuracy per epoch, in percent.

    The figure is written to ``training_history.png`` in the current working
    directory and closed afterwards.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # (train series, validation series, title, y-axis label) per panel
    panels = (
        (train_losses, val_losses, 'Loss over epochs', 'Loss'),
        (train_accs, val_accs, 'Accuracy over epochs', 'Accuracy (%)'),
    )
    for ax, (train_series, val_series, title, y_label) in zip(axes, panels):
        ax.plot(train_series, label='Train')
        ax.plot(val_series, label='Validation')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()

    fig.tight_layout()
    fig.savefig('training_history.png')
    plt.close(fig)

def main(load_model=True):
    """Train ``HandSignNet`` on the stored keypoints and plot the results.

    Loads the keypoint archive, makes a stratified 80/20 train/validation
    split, trains for ``EPOCHS`` epochs, saves the weights with the best
    validation accuracy to ``MODEL_PATH`` and finally writes the
    training-history and confusion-matrix plots.

    Args:
        load_model: When true and a checkpoint exists at ``MODEL_PATH``,
            training continues from it and its validation accuracy becomes
            the baseline a new checkpoint has to beat.
    """
    # Seed the RNGs for reproducibility
    torch.manual_seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)

    # Class names used as tick labels in the confusion matrix
    alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y']

    # Load data; the context manager closes the archive once the arrays are read
    print("Lade Daten...")
    with np.load('/home/geiger/asl_detection/machine_learning/datasets/asl_now/Keypoints_3/asl_keypoints.npz') as data:
        keypoints = data['keypoints']
        labels = data['labels']

    # Stratified split keeps the class proportions equal in both subsets
    X_train, X_val, y_train, y_val = train_test_split(
        keypoints, labels, test_size=0.2, random_state=RANDOM_SEED, stratify=labels
    )

    # Augmentation is enabled for the training set only
    train_dataset = HandSignDataset(X_train, y_train, augment=True)
    val_dataset = HandSignDataset(X_val, y_val, augment=False)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    print(f"Initialisiere Modell auf {DEVICE}...")
    model = HandSignNet().to(DEVICE)

    # Cross-entropy with label smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Optimizer with weight decay
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Halve the learning rate after 8 epochs without validation-loss improvement
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=8, factor=0.5)

    initial_val_acc = 0
    # True once the file at MODEL_PATH is known to belong to this run, i.e.
    # it was loaded as the starting point or written during training.
    checkpoint_is_current = False
    if load_model and os.path.exists(MODEL_PATH):
        print(f"Lade vortrainiertes Modell von {MODEL_PATH}...")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
        checkpoint_is_current = True
        print("Vortrainiertes Modell erfolgreich geladen!")

        # Validate the loaded model to get the baseline accuracy
        print("Validiere geladenes Modell...")
        _, initial_val_acc, _, _ = validate(model, val_loader, criterion, DEVICE)
        print(f"Initiale Validierungs-Genauigkeit: {initial_val_acc:.2f}%")
    else:
        if load_model:
            print(f"Kein vortrainiertes Modell gefunden unter {MODEL_PATH}. Starte mit neuem Modell.")
        else:
            print("Training mit neuem Modell gestartet.")

    best_val_acc = initial_val_acc

    # Training
    print("Starte Training...")
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    # Pre-bind so the fallback branch below works even if no epoch ran
    val_preds, val_labels = [], []

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, val_acc, val_preds, val_labels = validate(model, val_loader, criterion, DEVICE)

        # Learning-rate adjustment driven by the validation loss
        scheduler.step(val_loss)

        # Record metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # torch.save does not create missing parent directories
            model_dir = os.path.dirname(MODEL_PATH)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)
            torch.save(model.state_dict(), MODEL_PATH)
            checkpoint_is_current = True
            print(f"Neues bestes Modell gespeichert mit Accuracy: {val_acc:.2f}%")

        # Per-epoch report
        print(f'Epoch {epoch+1}/{EPOCHS}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
        print('-' * 50)

    # Evaluate the best checkpoint, but only one this run loaded or wrote; a
    # leftover file from an earlier run must not be reported as this result.
    if checkpoint_is_current and os.path.exists(MODEL_PATH):
        print("Lade bestes Modell für finale Evaluation...")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
        _, _, final_preds, final_labels = validate(model, val_loader, criterion, DEVICE)
    else:
        print("Kein gespeichertes Modell gefunden. Überspringe finale Evaluation.")
        # Fall back to the predictions of the last epoch for the plots
        final_preds = val_preds
        final_labels = val_labels

    # Plot results
    plot_training_history(train_losses, val_losses, train_accs, val_accs)
    if len(final_labels) > 0:  # only when validation predictions exist
        plot_confusion_matrix(final_labels, final_preds, alphabet)

    print(f"\nBeste Validierungs-Accuracy: {best_val_acc:.2f}%")
    if best_val_acc > initial_val_acc:
        print(f"Verbesserung gegenüber initialem Modell: +{best_val_acc - initial_val_acc:.2f}%")

if __name__ == "__main__":
    # No command-line argument is parsed here; the flag is hard-coded.
    load_model = False  # Default: start with a new model
    
    main(load_model) 

Lade Daten...
Initialisiere Modell auf cpu...
Training mit neuem Modell gestartet.
Starte Training...
Neues bestes Modell gespeichert mit Accuracy: 91.56%
Epoch 1/40:
Train Loss: 2.2902, Train Acc: 39.06%
Val Loss: 1.2998, Val Acc: 91.56%
LR: 0.000500
--------------------------------------------------
Neues bestes Modell gespeichert mit Accuracy: 96.93%
Epoch 2/40:
Train Loss: 1.6296, Train Acc: 64.05%
Val Loss: 1.0288, Val Acc: 96.93%
LR: 0.000500
--------------------------------------------------
Epoch 3/40:
Train Loss: 1.4668, Train Acc: 70.50%
Val Loss: 0.9412, Val Acc: 94.33%
LR: 0.000500
--------------------------------------------------
Neues bestes Modell gespeichert mit Accuracy: 97.70%
Epoch 4/40:
Train Loss: 1.3850, Train Acc: 74.37%
Val Loss: 0.9111, Val Acc: 97.70%
LR: 0.000500
--------------------------------------------------
Neues bestes Modell gespeichert mit Accuracy: 98.17%
Epoch 5/40:
Train Loss: 1.3255, Train Acc: 76.44%
Val Loss: 0.8503, Val Acc: 98.17%
LR: 0.0005