In [None]:
import traceback

import astra
import numpy as np

In [2]:
# Smoke-test helper for the ASTRA 2D reconstruction algorithms (debugging aid for notebooks).

# Algorithm types tested when the caller does not pass an explicit list.
_DEFAULT_ASTRA_ALGORITHMS = (
    'BP', 'BP_CUDA',
    'FBP', 'FBP_CUDA',
    'SIRT', 'SIRT_CUDA',
    'SART', 'SART_CUDA',
    'CGLS', 'CGLS_CUDA',
)

# Base algorithm types (CUDA suffix stripped) that are run for `num_iterations` iterations.
_ITERATIVE_ASTRA_ALGORITHMS = ('SIRT', 'SART', 'CGLS')


def _record_error(results, key, exc):
    """Store the message of `exc` and the current traceback under `key` in results['errors']."""
    results['errors'][key] = str(exc)
    # Must be called from inside an `except` block so format_exc() sees the active exception.
    results['errors'][f"{key}_traceback"] = traceback.format_exc()


def _make_disk_sinogram(test_size, projector_type, results):
    """Forward-project a centred disk phantom; return the sinogram, or None on failure.

    The phantom is a test_size x test_size float32 image equal to 1.0 inside a
    disk of radius test_size // 4 and 0 elsewhere, projected over 180
    parallel-beam angles evenly spaced in [0, pi). Failures are recorded in
    `results` under 'create_sinogram'. Every ASTRA object created here is
    deleted before returning, whether or not the projection succeeded.
    """
    phantom = np.zeros((test_size, test_size), dtype=np.float32)
    center = test_size // 2
    radius = test_size // 4
    y, x = np.ogrid[-center:test_size - center, -center:test_size - center]
    phantom[x * x + y * y <= radius * radius] = 1.0

    angles = np.linspace(0, np.pi, 180, endpoint=False)

    data_ids = []
    projector_id = None
    fp_id = None
    try:
        proj_geom = astra.create_proj_geom('parallel', 1.0, test_size, angles)
        vol_geom = astra.create_vol_geom(test_size, test_size)
        phantom_id = astra.data2d.create('-vol', vol_geom, phantom)
        data_ids.append(phantom_id)
        sinogram_id = astra.data2d.create('-sino', proj_geom)
        data_ids.append(sinogram_id)

        # The CPU 'FP' algorithm requires an explicit projector.
        projector_id = astra.create_projector(projector_type, proj_geom, vol_geom)
        cfg = astra.astra_dict('FP')
        cfg['ProjectorId'] = projector_id
        cfg['ProjectionDataId'] = sinogram_id
        cfg['VolumeDataId'] = phantom_id
        fp_id = astra.algorithm.create(cfg)
        astra.algorithm.run(fp_id)
        sinogram = astra.data2d.get(sinogram_id)
    except Exception as e:
        print(f"Error creating test sinogram: {e}")
        _record_error(results, 'create_sinogram', e)
        return None
    finally:
        # Release whatever was created, even if a later step failed.
        if fp_id is not None:
            astra.algorithm.delete(fp_id)
        if projector_id is not None:
            astra.projector.delete(projector_id)
        if data_ids:
            astra.data2d.delete(data_ids)

    results['success']['forward_projection'] = True
    print("Created test sinogram successfully")
    return sinogram


def _run_astra_algorithm(alg_type, sinogram_id, vol_geom, projector_id, num_iterations):
    """Run one ASTRA algorithm on `sinogram_id` and return summary statistics of the result.

    Each call reconstructs into a newly created volume, so iterative methods
    never start from a previous algorithm's output. The algorithm object and
    the volume are deleted even when creation or execution raises; the
    exception itself propagates to the caller.
    """
    is_cuda = alg_type.endswith('_CUDA')
    base_type = alg_type[:-len('_CUDA')] if is_cuda else alg_type

    rec_id = astra.data2d.create('-vol', vol_geom)
    alg_id = None
    try:
        cfg = astra.astra_dict(alg_type)
        cfg['ProjectionDataId'] = sinogram_id
        cfg['ReconstructionDataId'] = rec_id
        if not is_cuda and projector_id is not None:
            # CPU algorithms require a projector; the CUDA variants do not.
            cfg['ProjectorId'] = projector_id

        # Only pass each algorithm the options it understands.
        options = {}
        if base_type == 'SART':
            options['ProjectionOrder'] = 'random'
        if base_type == 'SIRT':
            options['MinConstraint'] = 0
            options['MaxConstraint'] = 255
        if options:
            cfg['option'] = options

        alg_id = astra.algorithm.create(cfg)
        if base_type in _ITERATIVE_ASTRA_ALGORITHMS:
            astra.algorithm.run(alg_id, num_iterations)
        else:
            astra.algorithm.run(alg_id)
        reconstruction = astra.data2d.get(rec_id)
    finally:
        if alg_id is not None:
            astra.algorithm.delete(alg_id)
        astra.data2d.delete(rec_id)

    return {
        'success': True,
        'shape': reconstruction.shape,
        'min': float(reconstruction.min()),
        'max': float(reconstruction.max()),
        'mean': float(reconstruction.mean()),
    }


def test_astra_algorithms(sinogram=None, test_size=256, algorithms=None,
                          num_iterations=20, projector_type='linear'):
    """
    Smoke-test ASTRA 2D reconstruction algorithms (debugging aid for Jupyter).

    Each requested algorithm is run on the same parallel-beam sinogram,
    reconstructing into its own freshly created volume, and summary statistics
    of the result are recorded. Errors are caught per algorithm, so one failure
    (e.g. a CUDA variant without a usable GPU) does not stop the remaining tests.

    Parameters:
    -----------
    sinogram : ndarray, optional
        2-D sinogram of shape (num_angles, num_detector_pixels). It is assumed
        to cover angles evenly spaced over [0, pi); the reconstruction volume
        is square with side num_detector_pixels. If None, a test sinogram of a
        centred disk phantom is generated.
    test_size : int
        Phantom side length and detector count of the generated sinogram
        (ignored when `sinogram` is given).
    algorithms : sequence of str, optional
        ASTRA algorithm types to test. Defaults to the CPU and CUDA variants
        of BP, FBP, SIRT, SART and CGLS.
    num_iterations : int
        Iterations run for SIRT, SART and CGLS (and their CUDA variants).
    projector_type : str
        Projector type passed to astra.create_projector for the CPU
        algorithms and for generating the test sinogram.

    Returns:
    --------
    results : dict
        'available_algorithms': value returned by astra.astra.algorithm_list(),
            or [] if that call failed.
        'tested_algorithms': per-algorithm dict with 'success' plus
            'shape'/'min'/'max'/'mean' on success, or 'error' on failure.
        'errors': error messages keyed by step/algorithm name, each with a
            matching '<key>_traceback' entry.
        'success': True for every step/algorithm that succeeded
            (including 'forward_projection' when a sinogram was generated).

    Raises:
    -------
    ValueError
        If `sinogram` is not 2-D.
    """
    results = {
        'available_algorithms': [],
        'tested_algorithms': {},
        'errors': {},
        'success': {}
    }

    try:
        # NOTE(review): confirm astra.astra.algorithm_list exists in the installed
        # ASTRA version; a failure here is recorded, not raised.
        alg_list = astra.astra.algorithm_list()
        results['available_algorithms'] = alg_list
        print(f"Available ASTRA algorithms: {alg_list}")
    except Exception as e:
        print(f"Error getting algorithm list: {e}")
        _record_error(results, 'list_algorithms', e)

    if sinogram is None:
        sinogram = _make_disk_sinogram(test_size, projector_type, results)
        if sinogram is None:
            return results

    sinogram = np.asarray(sinogram)
    if sinogram.ndim != 2:
        raise ValueError(
            "sinogram must be 2-D (num_angles, num_detector_pixels), "
            f"got shape {sinogram.shape}"
        )
    num_angles, num_detector_pixels = sinogram.shape
    print(f"Sinogram shape: {sinogram.shape}")

    algorithms = list(_DEFAULT_ASTRA_ALGORITHMS if algorithms is None else algorithms)

    # Reconstruction geometry: angles assumed evenly spaced over [0, pi) and a
    # square volume as wide as the detector.
    angles = np.linspace(0, np.pi, num_angles, endpoint=False)
    proj_geom = astra.create_proj_geom('parallel', 1.0, num_detector_pixels, angles)
    vol_geom = astra.create_vol_geom(num_detector_pixels, num_detector_pixels)

    sinogram_id = astra.data2d.create('-sino', proj_geom, sinogram)
    projector_id = None
    try:
        # Only the CPU algorithms need a projector.
        if any(not name.endswith('_CUDA') for name in algorithms):
            try:
                projector_id = astra.create_projector(projector_type, proj_geom, vol_geom)
            except Exception as e:
                print(f"Error creating projector: {e}")
                _record_error(results, 'create_projector', e)

        for alg_name in algorithms:
            print(f"Testing {alg_name}...")
            try:
                stats = _run_astra_algorithm(alg_name, sinogram_id, vol_geom,
                                             projector_id, num_iterations)
            except Exception as e:
                results['tested_algorithms'][alg_name] = {
                    'success': False,
                    'error': str(e)
                }
                _record_error(results, alg_name, e)
                print(f"  ✗ Error with {alg_name}: {e}")
            else:
                results['tested_algorithms'][alg_name] = stats
                results['success'][alg_name] = True
                print(f"  ✓ Success: {alg_name} worked")
    finally:
        if projector_id is not None:
            astra.projector.delete(projector_id)
        astra.data2d.delete(sinogram_id)

    successful = [k for k, v in results['tested_algorithms'].items() if v.get('success', False)]
    failed = [k for k, v in results['tested_algorithms'].items() if not v.get('success', False)]

    print("\nSummary:")
    print(f"Successfully tested {len(successful)} algorithms: {', '.join(successful)}")
    print(f"Failed to test {len(failed)} algorithms: {', '.join(failed)}")

    return results

In [None]:
# Run the ASTRA smoke test on the generated 256 x 256 disk-phantom sinogram.
results = test_astra_algorithms(sinogram=None, test_size=256)