# PyTorch Tensor Explorer

This notebook provides utilities for exploring and analyzing PyTorch tensor files (.pt) containing single tensors or tuples of tensors.
Make sure you have PyTorch installed with CUDA support.

In [17]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Union, Tuple, List, Dict, Any

# Check CUDA availability and version
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"cuDNN version: {torch.backends.cudnn.version()}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")

CUDA available: True
CUDA version: 12.1
cuDNN version: 8902
GPU device: NVIDIA GeForce GTX 1050 Ti with Max-Q Design


## Utility Functions for Tensor Analysis

In [18]:
def get_tensor_info(tensor: Union[torch.Tensor, Any]) -> Dict:
    """Extract basic information from a tensor or other object."""
    if torch.is_tensor(tensor):
        return {
            'type': type(tensor),
            'device': tensor.device,
            'dtype': tensor.dtype,
            'shape': tensor.shape,
            'requires_grad': tensor.requires_grad
        }
    else:
        return {
            'type': type(tensor),
            'value': str(tensor) if not isinstance(tensor, (dict, list, tuple)) else f"{type(tensor)} of length {len(tensor)}"
        }

def load_tensor_file(file_path: str) -> Tuple[Any, List[Dict]]:
    """Load a .pt file and return the content with information about each element.
    
    Args:
        file_path (str): Path to the .pt file
        
    Returns:
        tuple: (loaded_content, list of info dictionaries for each element)
    """
    content = torch.load(file_path)
    
    # Handle different types of content
    if isinstance(content, tuple):
        infos = [get_tensor_info(item) for item in content]
    else:
        infos = [get_tensor_info(content)]
        content = (content,)  # Convert to tuple for consistent handling
    
    return content, infos

def analyze_tensor_statistics(tensor: torch.Tensor) -> Union[Dict, str]:
    """Compute basic statistics for a tensor."""
    if not torch.is_tensor(tensor):
        return "Input is not a tensor"
    
    # Move tensor to CPU for analysis if needed
    tensor_cpu = tensor.cpu()
    
    stats = {
        'min': tensor_cpu.min().item(),
        'max': tensor_cpu.max().item(),
        'mean': tensor_cpu.mean().item(),
        'std': tensor_cpu.std().item(),
        'num_zeros': (tensor_cpu == 0).sum().item(),
        'num_non_zeros': (tensor_cpu != 0).sum().item(),
        'total_elements': tensor_cpu.numel(),
        'sparsity': (tensor_cpu == 0).sum().item() / tensor_cpu.numel()
    }
    
    return stats

def visualize_tensor_distribution(tensor: torch.Tensor, bins: int = 50, title: str = None):
    """Plot histogram of tensor values."""
    if not torch.is_tensor(tensor):
        print("Input is not a tensor")
        return
    
    plt.figure(figsize=(10, 6))
    tensor_np = tensor.cpu().detach().numpy()
    plt.hist(tensor_np.flatten(), bins=bins, density=True)
    plt.title(title or 'Tensor Value Distribution')
    plt.xlabel('Value')
    plt.ylabel('Density')
    plt.grid(True)
    plt.show()

def move_to_device(obj: Any, device: str = 'cuda') -> Any:
    """Recursively move tensor(s) to specified device."""
    if device == 'cuda' and not torch.cuda.is_available():
        print("CUDA is not available, keeping on CPU")
        return obj
    
    if torch.is_tensor(obj):
        return obj.to(device)
    elif isinstance(obj, tuple):
        return tuple(move_to_device(item, device) for item in obj)
    elif isinstance(obj, list):
        return [move_to_device(item, device) for item in obj]
    elif isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    else:
        return obj

## Example Usage

Replace 'path/to/your/tensor.pt' with the actual path to your tensor file.

In [21]:
# Load and analyze tensor file
file_path = '/home/tc/git_repo/mls/explore_mas/tr/Simple64_2021-11-21-11-55-57.pt'
tensors, tensor_infos = load_tensor_file(file_path)

# Print information about each tensor/item in the file
print("File Contents:")
for idx, (tensor, info) in enumerate(zip(tensors, tensor_infos)):
    print(f"\nItem {idx}:")
    for key, value in info.items():
        print(f"{key}: {value}")
    
    # If it's a tensor, compute and show statistics
    if torch.is_tensor(tensor):
        print("\nStatistics:")
        stats = analyze_tensor_statistics(tensor)
        for key, value in stats.items():
            print(f"{key}: {value}")
        
        # Visualize distribution
        # visualize_tensor_distribution(tensor, title=f'Distribution of Tensor {idx}')

File Contents:

Item 0:
type: <class 'torch.Tensor'>
device: cpu
dtype: torch.float64
shape: torch.Size([5, 1056596])
requires_grad: False

Statistics:
min: 0.0
max: 200.0
mean: 0.032278188002712926
std: 0.31135553441017716
num_zeros: 5119785
num_non_zeros: 163195
total_elements: 5282980
sparsity: 0.9691092905897808

Item 1:
type: <class 'torch.Tensor'>
device: cpu
dtype: torch.float64
shape: torch.Size([5, 11446])
requires_grad: False

Statistics:
min: 0.0
max: 1.0
mean: 0.0014852350165996855
std: 0.038510453223615036
num_zeros: 57145
num_non_zeros: 85
total_elements: 57230
sparsity: 0.9985147649834003


## Memory Management

In [None]:
def print_memory_stats():
    """Print current GPU memory usage."""
    if torch.cuda.is_available():
        print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

def clear_gpu_memory():
    """Clear GPU memory cache."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("GPU memory cache cleared")

# Example usage
print("Before clearing memory:")
print_memory_stats()

clear_gpu_memory()

print("\nAfter clearing memory:")
print_memory_stats()

## Batch Processing Multiple Tensor Files

In [22]:
def process_tensor_directory(directory_path: str) -> Dict:
    """Process all .pt files in a directory."""
    tensor_files = list(Path(directory_path).glob('*.pt'))
    results = {}
    
    for file_path in tensor_files:
        print(f"\nProcessing {file_path.name}")
        try:
            tensors, infos = load_tensor_file(file_path)
            
            file_results = []
            for idx, (tensor, info) in enumerate(zip(tensors, infos)):
                item_result = {'info': info}
                
                if torch.is_tensor(tensor):
                    item_result['stats'] = analyze_tensor_statistics(tensor)
                
                file_results.append(item_result)
            
            results[file_path.name] = file_results
            
            # Clear memory
            del tensors
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                
        except Exception as e:
            print(f"Error processing {file_path.name}: {str(e)}")
    
    return results

# Example usage
directory_path = '/home/tc/git_repo/mls/explore_mas/tr/'
results = process_tensor_directory(directory_path)

print("\nProcessing Results:")
for file_name, file_results in results.items():
    print(f"\nFile: {file_name}")
    for idx, item_result in enumerate(file_results):
        print(f"\nItem {idx}:")
        print("Info:")
        for key, value in item_result['info'].items():
            print(f"  {key}: {value}")
        if 'stats' in item_result:
            print("Statistics:")
            for key, value in item_result['stats'].items():
                print(f"  {key}: {value}")


Processing Simple64_2021-12-22-03-55-31.pt

Processing Simple64_2021-12-06-06-11-43.pt

Processing Simple64_2021-12-18-01-28-49.pt

Processing Simple64_2021-12-28-07-44-10.pt

Processing Simple64_2021-11-21-12-07-13.pt

Processing Simple64_2021-12-18-13-20-03.pt

Processing Simple64_2021-12-06-06-14-22.pt

Processing Simple64_2021-12-19-13-31-35.pt

Processing Simple64_2021-11-21-12-09-44.pt

Processing Simple64_2021-12-18-13-23-13.pt

Processing Simple64_2021-12-18-01-31-59.pt

Processing Simple64_2021-12-14-09-33-00.pt

Processing Simple64_2021-12-14-09-20-58.pt

Processing Simple64_2021-12-15-08-13-39.pt

Processing Simple64_2021-11-21-13-33-16.pt

Processing Simple64_2021-12-22-04-07-18.pt

Processing Simple64_2021-12-17-12-49-33.pt

Processing Simple64_2021-12-06-06-32-08.pt

Processing Simple64_2021-12-03-03-31-59.pt

Processing Simple64_2021-12-22-04-00-23.pt

Processing Simple64_2021-12-18-01-30-22.pt

Processing Simple64_2021-12-06-06-25-18.pt

Processing Simple64_2021-12-14-