In [None]:
import wandb
import torch

# Log in to Weights & Biases before any run or artifact access.
wandb.login()

# Start a run in the project; both artifacts below are fetched through this run,
# so W&B records them as inputs of the run.
# NOTE(review): this run is never finished in this cell — confirm whether a
# later cell is expected to close it.
run = wandb.init(project="facial-expression-recognition")

# Download the CNN artifact to a local directory (`cnn_dir` is read by the
# model-loading cell further down).
# NOTE(review): the artifact name "run-30aso492-history" looks like W&B's
# auto-generated run-history artifact rather than a model checkpoint — confirm
# that it really contains the CNN weights.
cnn_artifact = run.use_artifact('ellekvirikashvili-free-university-of-tbilisi-/facial-expression-recognition/run-30aso492-history:v0')
cnn_dir = cnn_artifact.download()

# Download the ViT artifact to a local directory (`vit_dir` is read by the
# model-loading cell further down).
vit_artifact = run.use_artifact('ellekvirikashvili-free-university-of-tbilisi-/facial-expression-recognition/vit-fer2013-final-model:v0')
vit_dir = vit_artifact.download()

In [None]:
# Load the models from the downloaded artifact directories.
# Assumes each artifact directory contains a `model.pth` holding a whole pickled
# model object (not just a state_dict) — adjust the file name / loading code to
# match how the checkpoints were actually saved.
#
# `map_location` remaps tensors saved on a GPU onto whatever device is available
# here, so the cell also works on a CPU-only machine (same policy as
# FERensemble.load_models).
#
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# checkpoints from a trusted source. Recent PyTorch releases default to
# `weights_only=True`, which rejects whole-model pickles; confirm against the
# installed version.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn_model = torch.load(f"{cnn_dir}/model.pth", map_location=device)
vit_model = torch.load(f"{vit_dir}/model.pth", map_location=device)

# Switch both models to evaluation mode before inference.
cnn_model.eval()
vit_model.eval()

In [None]:
def ensemble_predict(cnn_model, vit_model, x, weights=(0.5, 0.5)):
    """Combine CNN and ViT predictions by a weighted average of their softmax outputs.

    Args:
        cnn_model: Callable applied to ``x``; its output is soft-maxed over dim 1.
        vit_model: Callable applied to the same ``x``; its output is soft-maxed
            over dim 1.
        x: Input batch passed unchanged to both models.
        weights: Two-element sequence ``(cnn_weight, vit_weight)``. Defaults to
            an equal split. The values are used as given (no normalisation).

    Returns:
        Tuple ``(ensemble_pred, cnn_pred, vit_pred)`` of probability tensors.
    """
    # The default is an immutable tuple rather than a list so the shared default
    # object can never be mutated between calls; lists passed by callers still work.
    cnn_weight, vit_weight = weights[0], weights[1]

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        cnn_pred = torch.softmax(cnn_model(x), dim=1)
        vit_pred = torch.softmax(vit_model(x), dim=1)

        # Weighted average of the two probability distributions.
        ensemble_pred = cnn_weight * cnn_pred + vit_weight * vit_pred

    return ensemble_pred, cnn_pred, vit_pred

In [None]:
from sklearn.metrics import accuracy_score
import numpy as np

def find_optimal_weights(cnn_preds, vit_preds, true_labels):
    """Grid-search the CNN/ViT mixing weight that maximises accuracy.

    The CNN weight ``w1`` is swept over ``np.arange(0.1, 1.0, 0.1)`` and the ViT
    weight is ``1 - w1``. The first weight pair reaching the highest accuracy is
    kept (later ties do not replace it).

    Args:
        cnn_preds: Tensor of CNN class probabilities, classes along dim 1.
        vit_preds: Tensor of ViT class probabilities, same layout as ``cnn_preds``.
        true_labels: Ground-truth class indices (tensor, array or list).

    Returns:
        Tuple ``(best_weights, best_acc)`` where ``best_weights`` is
        ``[cnn_weight, vit_weight]`` as plain Python floats. If no candidate
        scores above 0, the initial ``[0.5, 0.5]`` and ``0`` are returned.
    """
    # scikit-learn works on host arrays, so move tensor labels off the GPU once
    # up front instead of handing it (possibly CUDA) tensors on every iteration.
    if torch.is_tensor(true_labels):
        true_labels = true_labels.detach().cpu().numpy()

    best_acc = 0
    best_weights = [0.5, 0.5]

    for w1 in np.arange(0.1, 1.0, 0.1):
        # Cast the NumPy scalar to a Python float so the multiplication below is
        # an ordinary tensor-by-scalar op (no NumPy dispatch on the tensor) and
        # the returned weights are plain floats.
        w1 = float(w1)
        w2 = 1 - w1

        ensemble_pred = w1 * cnn_preds + w2 * vit_preds
        pred_labels = torch.argmax(ensemble_pred, dim=1)
        acc = accuracy_score(true_labels, pred_labels.detach().cpu().numpy())

        if acc > best_acc:
            best_acc = acc
            best_weights = [w1, w2]

    return best_weights, best_acc

In [None]:
import wandb
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
from torch.utils.data import DataLoader

class FERensemble:
    """Weighted soft-voting ensemble of a CNN and a ViT classifier.

    Typical flow: ``download_models`` -> ``load_models`` -> ``optimize_weights``
    (validation data) -> ``compare_models`` (test data).

    Attributes:
        project_name: W&B project used when starting the download run.
        cnn_model / vit_model: Loaded models (``None`` until ``load_models``).
        weights: ``[cnn_weight, vit_weight]`` applied to the softmax outputs.
    """

    # Default W&B artifact references used by ``download_models``.
    # NOTE(review): the CNN reference is named "run-...-history", which looks
    # like W&B's auto-generated run-history artifact rather than a model
    # checkpoint — confirm that it really contains the CNN weights.
    DEFAULT_CNN_ARTIFACT = (
        'ellekvirikashvili-free-university-of-tbilisi-/facial-expression-recognition/run-30aso492-history:v0'
    )
    DEFAULT_VIT_ARTIFACT = (
        'ellekvirikashvili-free-university-of-tbilisi-/facial-expression-recognition/vit-fer2013-final-model:v0'
    )

    # Class names used by ``get_metrics`` when the caller passes none; position
    # i names class index i (presumably the FER2013 label order — confirm).
    DEFAULT_CLASS_NAMES = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

    def __init__(self, project_name="facial-expression-recognition"):
        """Create an empty ensemble; models are attached later by ``load_models``."""
        self.project_name = project_name
        self.cnn_model = None
        self.vit_model = None
        self.weights = [0.5, 0.5]  # Default equal weights

    def download_models(self, cnn_artifact=None, vit_artifact=None):
        """Download both model artifacts from W&B.

        Starts a W&B run in ``self.project_name`` and leaves it open, so that
        ``compare_models`` can log into it; the caller is responsible for
        finishing the run.

        Args:
            cnn_artifact: Artifact reference for the CNN. Defaults to
                ``DEFAULT_CNN_ARTIFACT``.
            vit_artifact: Artifact reference for the ViT (requested with
                ``type='model'``). Defaults to ``DEFAULT_VIT_ARTIFACT``.

        Returns:
            Tuple ``(cnn_dir, vit_dir)`` of local download directories.
        """
        if cnn_artifact is None:
            cnn_artifact = self.DEFAULT_CNN_ARTIFACT
        if vit_artifact is None:
            vit_artifact = self.DEFAULT_VIT_ARTIFACT

        print("Downloading models from W&B...")

        # Initialize wandb
        run = wandb.init(project=self.project_name)

        # Download CNN model
        print("Downloading CNN model...")
        cnn_dir = run.use_artifact(cnn_artifact).download()

        # Download ViT model
        print("Downloading ViT model...")
        vit_dir = run.use_artifact(vit_artifact, type='model').download()

        return cnn_dir, vit_dir

    def load_models(self, cnn_path, vit_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Load both checkpoints onto ``device`` and put them in eval mode.

        Each file is expected to contain a whole pickled model object, because
        ``.eval()`` is called directly on the loaded value.

        NOTE(review): ``torch.load`` unpickles arbitrary Python objects — only
        load checkpoints from a trusted source. Recent PyTorch releases default
        to ``weights_only=True``, which rejects whole-model pickles; confirm
        against the installed version.

        Args:
            cnn_path: Path to the CNN checkpoint file.
            vit_path: Path to the ViT checkpoint file.
            device: Target device passed to ``map_location``.
        """
        print(f"Loading models on {device}...")

        # Load CNN model
        self.cnn_model = torch.load(cnn_path, map_location=device)
        self.cnn_model.eval()

        # Load ViT model
        self.vit_model = torch.load(vit_path, map_location=device)
        self.vit_model.eval()

        print("Models loaded successfully!")

    def predict_single(self, x, return_individual=False):
        """Predict class probabilities for one batch.

        Args:
            x: Input batch, passed unchanged to both models (it must already be
                on the models' device).
            return_individual: When true, also return each model's probabilities.

        Returns:
            ``ensemble_probs``, or ``(ensemble_probs, cnn_probs, vit_probs)``
            when ``return_individual`` is true.
        """
        with torch.no_grad():
            # Softmax over dim 1 turns each model's logits into probabilities.
            cnn_probs = F.softmax(self.cnn_model(x), dim=1)
            vit_probs = F.softmax(self.vit_model(x), dim=1)

            # Ensemble prediction (weighted average)
            ensemble_probs = self.weights[0] * cnn_probs + self.weights[1] * vit_probs

            if return_individual:
                return ensemble_probs, cnn_probs, vit_probs
            return ensemble_probs

    def evaluate_on_dataset(self, dataloader, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Run both models and the ensemble over every batch of ``dataloader``.

        Args:
            dataloader: Iterable yielding ``(data, labels)`` tensor pairs.
            device: Device each batch is moved to before prediction.

        Returns:
            Tuple ``(ensemble_probs, cnn_probs, vit_probs, labels)`` of CPU
            tensors concatenated along dim 0.

        Raises:
            RuntimeError: If ``dataloader`` yields no batches.
        """
        all_ensemble_preds = []
        all_cnn_preds = []
        all_vit_preds = []
        all_labels = []

        print("Evaluating on dataset...")

        for batch_idx, (data, labels) in enumerate(dataloader):
            data, labels = data.to(device), labels.to(device)

            # Get predictions
            ensemble_probs, cnn_probs, vit_probs = self.predict_single(data, return_individual=True)

            # Keep results on the CPU so accumulated batches do not hold GPU memory.
            all_ensemble_preds.append(ensemble_probs.cpu())
            all_cnn_preds.append(cnn_probs.cpu())
            all_vit_preds.append(vit_probs.cpu())
            all_labels.append(labels.cpu())

            if batch_idx % 50 == 0:
                print(f"Processed {batch_idx} batches...")

        # torch.cat fails with an opaque message on an empty list; state the real
        # problem instead (same exception type as before).
        if not all_labels:
            raise RuntimeError("dataloader yielded no batches; nothing to evaluate")

        # Concatenate all predictions
        return (
            torch.cat(all_ensemble_preds, dim=0),
            torch.cat(all_cnn_preds, dim=0),
            torch.cat(all_vit_preds, dim=0),
            torch.cat(all_labels, dim=0),
        )

    def optimize_weights(self, val_dataloader, device='cuda' if torch.cuda.is_available() else 'cpu',
                         weight_grid=None):
        """Grid-search the ensemble weights on a validation set.

        The best pair is stored in ``self.weights``. The first candidate
        reaching the highest accuracy wins (later ties do not replace it).

        Args:
            val_dataloader: Iterable yielding ``(data, labels)`` tensor pairs.
            device: Device each batch is moved to before prediction.
            weight_grid: Optional iterable of candidate CNN weights ``w1`` (the
                ViT weight is ``1 - w1``). Defaults to
                ``np.arange(0.1, 1.0, 0.1)``.

        Returns:
            Tuple ``(best_weights, best_acc)`` with ``best_weights`` as
            ``[cnn_weight, vit_weight]`` plain Python floats.
        """
        print("Finding optimal ensemble weights...")

        # Both models run only once; the grid search below reuses their outputs.
        _, cnn_preds, vit_preds, true_labels = self.evaluate_on_dataset(val_dataloader, device)
        y_true = true_labels.numpy()

        if weight_grid is None:
            weight_grid = np.arange(0.1, 1.0, 0.1)

        best_acc = 0
        best_weights = [0.5, 0.5]

        # Grid search for optimal weights
        for w1 in weight_grid:
            # Plain Python floats keep the tensor arithmetic free of NumPy
            # scalar dispatch and make the stored weights trivially serialisable.
            w1 = float(w1)
            w2 = 1 - w1

            # Calculate ensemble predictions with these weights
            ensemble_pred = w1 * cnn_preds + w2 * vit_preds
            pred_labels = torch.argmax(ensemble_pred, dim=1)

            # Calculate accuracy
            acc = accuracy_score(y_true, pred_labels.numpy())

            if acc > best_acc:
                best_acc = acc
                best_weights = [w1, w2]

        self.weights = best_weights
        print(f"Optimal weights found: CNN={best_weights[0]:.2f}, ViT={best_weights[1]:.2f}")
        print(f"Best validation accuracy: {best_acc:.4f}")

        return best_weights, best_acc

    def get_metrics(self, predictions, true_labels, class_names=None):
        """Compute accuracy and a per-class classification report.

        Args:
            predictions: Tensor of class probabilities, classes along dim 1.
            true_labels: Tensor of ground-truth class indices.
            class_names: Optional names, position i naming class index i.
                Defaults to ``DEFAULT_CLASS_NAMES``.

        Returns:
            Tuple ``(accuracy, report)`` where ``report`` is the dict form of
            scikit-learn's classification report.
        """
        y_pred = torch.argmax(predictions, dim=1).detach().cpu().numpy()
        y_true = true_labels.detach().cpu().numpy()

        # Basic accuracy
        accuracy = accuracy_score(y_true, y_pred)

        # Classification report
        if class_names is None:
            class_names = self.DEFAULT_CLASS_NAMES

        # Pass the label set explicitly: without it scikit-learn infers the
        # classes from the data and rejects `target_names` whenever a class is
        # missing from this split. `zero_division=0` keeps the default value for
        # such empty classes but without the warning.
        report = classification_report(
            y_true,
            y_pred,
            labels=list(range(len(class_names))),
            target_names=class_names,
            output_dict=True,
            zero_division=0,
        )

        return accuracy, report

    def compare_models(self, test_dataloader, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Evaluate CNN, ViT and ensemble on a test set and report the results.

        Prints a summary table and, when a W&B run is active, logs the
        accuracies and the current weights to it.

        Args:
            test_dataloader: Iterable yielding ``(data, labels)`` tensor pairs.
            device: Device each batch is moved to before prediction.

        Returns:
            Dict with keys ``'ensemble'``, ``'cnn'`` and ``'vit'``, each mapping
            to ``{'accuracy': ..., 'report': ...}``.
        """
        print("Comparing model performances...")

        ensemble_preds, cnn_preds, vit_preds, true_labels = self.evaluate_on_dataset(test_dataloader, device)

        # Calculate accuracies
        ensemble_acc, ensemble_report = self.get_metrics(ensemble_preds, true_labels)
        cnn_acc, cnn_report = self.get_metrics(cnn_preds, true_labels)
        vit_acc, vit_report = self.get_metrics(vit_preds, true_labels)

        # Gain of the ensemble over the better of the two single models.
        improvement = ensemble_acc - max(cnn_acc, vit_acc)

        print("\n" + "="*50)
        print("MODEL COMPARISON RESULTS")
        print("="*50)
        print(f"CNN Accuracy:      {cnn_acc:.4f}")
        print(f"ViT Accuracy:      {vit_acc:.4f}")
        print(f"Ensemble Accuracy: {ensemble_acc:.4f}")
        print(f"Improvement:       {improvement:.4f}")
        print("="*50)

        # Log to W&B only when a run is active: models may have been loaded from
        # local files without download_models(), and wandb.log needs a run.
        if wandb.run is not None:
            wandb.log({
                "cnn_accuracy": cnn_acc,
                "vit_accuracy": vit_acc,
                "ensemble_accuracy": ensemble_acc,
                "ensemble_improvement": improvement,
                "optimal_cnn_weight": self.weights[0],
                "optimal_vit_weight": self.weights[1]
            })
        else:
            print("No active W&B run; skipping metric logging.")

        return {
            'ensemble': {'accuracy': ensemble_acc, 'report': ensemble_report},
            'cnn': {'accuracy': cnn_acc, 'report': cnn_report},
            'vit': {'accuracy': vit_acc, 'report': vit_report}
        }

# Example usage
def main():
    """Download both models, load them, and leave hooks for evaluation.

    The ensemble object is local to this function, so the W&B run started by
    ``download_models`` has no use once the function returns; it is therefore
    finished on the way out, including when a step raises, instead of being
    left open.
    """
    # Initialize ensemble
    ensemble = FERensemble()

    try:
        # Download models
        cnn_dir, vit_dir = ensemble.download_models()

        # Load models (adjust paths based on your saved model structure)
        cnn_model_path = f"{cnn_dir}/model.pth"  # or wherever your .pth file is
        vit_model_path = f"{vit_dir}/model.pth"

        ensemble.load_models(cnn_model_path, vit_model_path)

        # Assuming you have your dataloaders ready
        # val_dataloader = your_validation_dataloader
        # test_dataloader = your_test_dataloader

        # Optimize ensemble weights on validation set
        # ensemble.optimize_weights(val_dataloader)

        # Compare models on test set (logs to the still-open W&B run)
        # results = ensemble.compare_models(test_dataloader)

        print("Ensemble setup complete!")
    finally:
        # Close the W&B run opened by download_models(); this does nothing when
        # no run is active.
        wandb.finish()


if __name__ == "__main__":
    main()

: 