In [None]:
"""
分析所有卷积层 (Conv1, Conv2, Conv3) 的LTR-RT特征
比较不同分辨率下的特征表示
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches  # 新增导入
from sklearn import preprocessing

def tensor_to_numpy(tensor):
    """Convert a tensor-like value to a NumPy array via ``tolist()``.

    NumPy arrays are returned unchanged (same object). Torch tensors are
    detached from the autograd graph when they require grad, moved to the
    CPU when they live on a CUDA device, and then rebuilt from the nested
    Python list produced by ``Tensor.tolist()``. Anything else (lists,
    tuples, scalars, ...) is handed straight to ``np.array``.
    """
    if isinstance(tensor, np.ndarray):
        return tensor

    # Lists, tuples and every other non-tensor value share the same path.
    if not isinstance(tensor, torch.Tensor):
        return np.array(tensor)

    plain = tensor.detach() if tensor.requires_grad else tensor
    on_host = plain.cpu() if plain.is_cuda else plain
    return np.array(on_host.tolist())

# ==================== 配置 ====================
# Start / end of the LTR-RT element, in bp of the original input sequence.
LTR_START_ORIG = 20497
LTR_END_ORIG = 29503

# Per-layer mapping parameters.
#   index    - position of the layer's activation inside the loaded ``mid_res``
#   stride   - number of input bp represented by one feature-map position
#   channels - number of output channels of the layer
#   name     - human-readable label
#
# The indices below select the raw convolution outputs, i.e. the maps taken
# BEFORE each layer's own max-pool (widths 49981 / 4979 / 297 for a 50000 bp
# input).  Their cumulative down-sampling is therefore 1x / 10x / 150x; using
# 10 / 150 / 2250 (the strides of the pooled maps) would place the LTR-RT
# window at the wrong feature positions.
LAYER_CONFIG = {
    'Conv1': {
        'index': 1,
        'stride': 1,
        'channels': 32,
        'name': 'Conv1 (full resolution, 1x)'
    },
    'Conv2': {
        'index': 4,
        'stride': 10,
        'channels': 64,
        'name': 'Conv2 (after MaxPool 10x)'
    },
    'Conv3': {
        'index': 7,
        'stride': 150,
        'channels': 128,
        'name': 'Conv3 (after MaxPool 150x)'
    }
}

# LTR-RT interval of each layer, expressed in feature-map positions of the
# maps selected by LAYER_CONFIG[...]['index'].
LTR_INTERVALS = {
    'Conv1': {
        'all_ltr': (20497, 29484)
    },
    'Conv2': {
        'all_ltr': (2050, 2929)
    },
    'Conv3': {
        'all_ltr': (137, 160)
    }
}

# ==================== Load data ====================
print("加载数据...")
mid_res = torch.load('tensor_data.pt')

print(f"\nmid_res类型: {type(mid_res)}")
print(f"mid_res长度: {len(mid_res)}")
print("\n前几个元素的信息:")

# Describe at most the first ten entries so the layout of the dump is visible.
preview_count = min(10, len(mid_res))
for idx in range(preview_count):
    entry = mid_res[idx]

    if isinstance(entry, torch.Tensor):
        print(f" [{idx}] Tensor: shape={entry.shape}, device={entry.device}")
        continue

    if not isinstance(entry, (list, tuple)):
        print(f" [{idx}] 其他类型: {type(entry)}")
        continue

    # Sequence entry: report its length and, if it starts with a tensor,
    # that tensor's shape and device.
    print(f" [{idx}] List/Tuple: 长度={len(entry)}")
    if len(entry) > 0 and isinstance(entry[0], torch.Tensor):
        head = entry[0]
        print(f"     第一个元素: shape={head.shape}, device={head.device}")

# ==================== 分析函数 ====================
def analyze_layer(layer_name, layer_config, mid_res):
    """Analyze one convolutional layer's activations around the LTR-RT region.

    The layer output is fetched from ``mid_res``, converted to NumPy, reduced
    to a 2-D ``(channels, positions)`` matrix, robust-scaled per channel, and
    a per-channel selectivity index contrasting the LTR-RT window with the
    rest of the map is computed.  Progress is printed along the way.

    Args:
        layer_name: Label used in the printed report.
        layer_config: Mapping providing ``'index'`` (position of the layer
            output inside ``mid_res``) and ``'stride'`` (input bp per feature
            position, used to map ``LTR_START_ORIG`` / ``LTR_END_ORIG`` onto
            the feature map).
        mid_res: Indexable collection of layer outputs.  An entry may be a
            tensor, or a list/tuple whose first element is used.

    Returns:
        A dict with keys ``'data'``, ``'data_norm'``, ``'ltr_start'``,
        ``'ltr_end'``, ``'selectivity'``, ``'top_channels'`` and ``'stats'``
        (``'ltr_selective'``, ``'bg_selective'``, ``'neutral'`` counts), or
        ``None`` when any step fails; the error and traceback are printed in
        that case instead of being raised.
    """
    print(f"\n{'='*60}")
    print(f"分析 {layer_name}")
    print(f"{'='*60}")
    try:
        layer_output = mid_res[layer_config['index']]
        print(f"原始数据类型: {type(layer_output)}")
        if isinstance(layer_output, (list, tuple)):
            print(f" 是序列类型，长度={len(layer_output)}")
            tensor = layer_output[0]
        else:
            tensor = layer_output
        print(f"Tensor类型: {type(tensor)}")
        print(f"Tensor形状: {tensor.shape}")
        print(f"Tensor设备: {tensor.device}")
        print(f"需要梯度: {tensor.requires_grad}")

        print("正在转换为numpy...")
        data = tensor_to_numpy(tensor)
        print(f"✅ 转换成功！numpy形状: {data.shape}")

        # Drop singleton axes (e.g. batch / height) to obtain a 2-D matrix.
        print(f"原始形状: {data.shape}")
        if data.ndim > 2:
            squeezed = np.squeeze(data)
            if squeezed.shape != data.shape:
                data = squeezed
                print(f" squeeze后: {data.shape}")
        # Everything below indexes the array as (channels, positions); fail
        # here with a clear message rather than deep inside the scaler.
        if data.ndim != 2:
            raise ValueError(
                f"特征图无法化简为二维 (通道 × 位置)，当前形状: {data.shape}"
            )
        print(f"✅ 最终形状: {data.shape}")

        # Map the LTR-RT coordinates from input bp to feature-map positions.
        stride = layer_config['stride']
        ltr_start = LTR_START_ORIG // stride
        ltr_end = LTR_END_ORIG // stride
        ltr_width = ltr_end - ltr_start

        print(f"\n特征图形状: {data.shape}")
        print(f"LTR-RT原始位置: {LTR_START_ORIG} - {LTR_END_ORIG} bp ({LTR_END_ORIG - LTR_START_ORIG} bp)")
        print(f"LTR-RT特征位置: {ltr_start} - {ltr_end} ({ltr_width} 个位置)")
        print(f"分辨率: 每个特征位置 ≈ {stride} bp")

        # Half-open window [ltr_start, ltr_end) as a boolean column mask.  A
        # mask cannot index out of bounds, and an empty LTR or background
        # region (which would yield NaN statistics) is rejected explicitly.
        n_positions = data.shape[1]
        positions = np.arange(n_positions)
        in_ltr = (positions >= ltr_start) & (positions < ltr_end)
        if not in_ltr.any() or in_ltr.all():
            raise ValueError(
                f"LTR-RT 窗口 [{ltr_start}, {ltr_end}) 与特征图宽度 {n_positions} 不匹配，"
                "请检查该层的 stride 配置"
            )

        print("正在归一化...")
        # The scaler works column-wise, so transpose to (positions, channels)
        # to scale every channel independently, then transpose back.
        scaler = preprocessing.RobustScaler()
        data_norm = scaler.fit_transform(data.T).T
        print("✅ 归一化完成")

        print("计算选择性指数...")
        ltr_region = data[:, in_ltr]
        bg_region = data[:, ~in_ltr]

        ltr_activation = ltr_region.mean(axis=1)
        bg_activation = bg_region.mean(axis=1)
        # Contrast index (LTR - BG) / (LTR + BG); the epsilon guards against a
        # zero sum.  NOTE(review): the index is only bounded to [-1, 1] when
        # both means are non-negative - confirm for pre-activation outputs.
        selectivity = (ltr_activation - bg_activation) / (ltr_activation + bg_activation + 1e-8)

        ltr_selective = np.sum(selectivity > 0.2)
        bg_selective = np.sum(selectivity < -0.2)
        neutral = np.sum(np.abs(selectivity) <= 0.2)

        print(f"\n选择性统计:")
        print(f" LTR特异性通道 (SI > 0.2): {ltr_selective} ({100*ltr_selective/len(selectivity):.1f}%)")
        print(f" 背景特异性通道 (SI < -0.2): {bg_selective} ({100*bg_selective/len(selectivity):.1f}%)")
        print(f" 中性通道: {neutral} ({100*neutral/len(selectivity):.1f}%)")

        # Five channels with the largest |SI|, strongest first.
        top_ch = np.argsort(np.abs(selectivity))[-5:][::-1]
        print(f"\nTop 5 判别通道:")
        for i, ch in enumerate(top_ch, 1):
            print(f" {i}. Channel {ch}: SI = {selectivity[ch]:+.3f}")

        return {
            'data': data,
            'data_norm': data_norm,
            'ltr_start': ltr_start,
            'ltr_end': ltr_end,
            'selectivity': selectivity,
            'top_channels': top_ch,
            'stats': {
                'ltr_selective': ltr_selective,
                'bg_selective': bg_selective,
                'neutral': neutral
            }
        }
    except Exception as e:
        # Best-effort boundary: report the failure and let the caller skip
        # this layer instead of aborting the whole analysis.
        print(f"\n❌ 错误: {e}")
        print(f"错误类型: {type(e).__name__}")
        import traceback
        traceback.print_exc()
        return None

# ==================== Analyze every layer ====================
# Collect the per-layer results; layers whose analysis failed (None) are
# reported and skipped.
results = {}
for layer_name, config in LAYER_CONFIG.items():
    result = analyze_layer(layer_name, config, mid_res)
    if result is None:
        print(f"\n⚠️ {layer_name} 分析失败，跳过")
        continue
    results[layer_name] = result

if not results:
    print("\n❌ 所有层都分析失败！")
    print("请检查数据格式和索引是否正确")
    # ``exit`` is the interactive helper installed by ``site`` (and replaced
    # by IPython inside notebooks); raise SystemExit for a well-defined stop.
    raise SystemExit(1)

print(f"\n✅ 成功分析了 {len(results)} 个层")

加载数据...

mid_res类型: <class 'list'>
mid_res长度: 17

前几个元素的信息:
 [0] Tensor: shape=torch.Size([1, 1, 5, 50000]), device=cpu
 [1] Tensor: shape=torch.Size([1, 32, 1, 49981]), device=cpu
 [2] Tensor: shape=torch.Size([1, 32, 1, 49981]), device=cpu
 [3] Tensor: shape=torch.Size([1, 32, 1, 4998]), device=cpu
 [4] Tensor: shape=torch.Size([1, 64, 1, 4979]), device=cpu
 [5] Tensor: shape=torch.Size([1, 64, 1, 4979]), device=cpu
 [6] Tensor: shape=torch.Size([1, 64, 1, 331]), device=cpu
 [7] Tensor: shape=torch.Size([1, 128, 1, 297]), device=cpu
 [8] Tensor: shape=torch.Size([1, 128, 1, 297]), device=cpu
 [9] Tensor: shape=torch.Size([1, 128, 1, 19]), device=cpu

分析 Conv1
原始数据类型: <class 'torch.Tensor'>
Tensor类型: <class 'torch.Tensor'>
Tensor形状: torch.Size([1, 32, 1, 49981])
Tensor设备: cpu
需要梯度: False
正在转换为numpy...
✅ 转换成功！numpy形状: (1, 32, 1, 49981)
原始形状: (1, 32, 1, 49981)
 squeeze后: (32, 49981)
✅ 最终形状: (32, 49981)

特征图形状: (32, 49981)
LTR-RT原始位置: 20497 - 29503 bp (9006 bp)
LTR-RT特征位置: 2049 - 2950 (9

In [None]:
# ==================== Render highlighted heatmaps ====================
# For every analyzed layer: draw the normalized activation heatmap, outline
# the layer's LTR-RT interval with a lime-green box, and save it as a PNG.
print(f"\n{'='*60}")
print("生成带 LTR-RT 区间高亮的热图...")
print(f"{'='*60}\n")

for layer_name, result in results.items():
    data_channels, data_width = result['data'].shape

    fig_width = 12  # widened to leave room for the legend
    fig_height = 10
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    # Heatmap of the robust-scaled activations, colour range clipped to [-2, 2].
    im = ax.imshow(result['data_norm'], aspect='auto', cmap='RdYlBu_r',
                   interpolation='nearest', vmin=-2, vmax=2)

    # Look the interval up once; it is reused for the printed summary below.
    intervals = LTR_INTERVALS.get(layer_name)
    if intervals is not None:
        all_start, all_end = intervals['all_ltr']
        all_width = all_end - all_start

        # imshow centres pixel i on x = i, hence the -0.5 offsets.
        # NOTE(review): if `all_end` is meant to be inclusive, the width should
        # be `all_width + 1` to cover the last column - confirm the convention.
        all_rect = patches.Rectangle(
            (all_start - 0.5, -0.5),
            all_width,
            data_channels,
            linewidth=2.5,
            edgecolor='limegreen',
            facecolor='none',
            linestyle='-',
            label='LTR-RT',
            zorder=10
        )
        ax.add_patch(all_rect)

    # Title and axis labels; the layer key itself is the description, so an
    # unexpected layer name is never mislabelled.
    layer_desc = layer_name
    ax.set_title(f"{layer_desc} with LTR-RT Regions\n"
                 f"Shape: {data_channels} channels × {data_width} positions",
                 fontsize=30, fontweight='bold', pad=15)
    ax.set_xlabel('Feature Map Position', fontsize=25)
    ax.set_ylabel('Channel', fontsize=25)
    ax.tick_params(axis='both', labelsize=25)

    # Colour bar
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Normalized Activation', rotation=270, labelpad=20, fontsize=25)

    # Legend (labels de-duplicated, keeping the last handle per label)
    handles, labels = ax.get_legend_handles_labels()
    if labels:
        by_label = dict(zip(labels, handles))
        ax.legend(by_label.values(), by_label.keys(),
                  loc='upper right', fontsize=25, framealpha=0.9)

    # Y ticks / grid: roughly 16 ticks up to 64 channels, about 20 beyond that.
    tick_divisor = 16 if data_channels <= 64 else 20
    ytick_step = max(1, data_channels // tick_divisor)
    ax.set_yticks(np.arange(0, data_channels, ytick_step))
    ax.grid(True, alpha=0.2, axis='y', linestyle=':')

    fig.tight_layout()

    # Save
    filename = f'{layer_name}_heatmap_highlighted.png'
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"✅ {layer_name} 高亮热图已保存: {filename}")
    print(f" - 图片尺寸: {fig_width}×{fig_height} inches")
    print(f" - 特征图: {data_channels} channels × {data_width} positions")

    # Interval summary
    if intervals is not None:
        print(f" - LTR-RT: [{all_start}, {all_end}] ({all_end-all_start} pos)")

    print()
    # Close this specific figure so figures do not accumulate across layers.
    plt.close(fig)

print(f"{'='*60}")
print("✅ 完成！所有带高亮的热图已生成")
print(f"{'='*60}")
print("\n图例说明:")
# The box is drawn with edgecolor='limegreen', so describe it as green.
print("  🟩 绿色实线框: LTR-RT 区间 (卷积核完全在LTR-RT内)\n")


生成带 LTR-RT 区间高亮的热图...

✅ Conv1 高亮热图已保存: Conv1_heatmap_highlighted.png
 - 图片尺寸: 12×10 inches
 - 特征图: 32 channels × 49981 positions
 - LTR-RT: [20497, 29484] (8987 pos)

✅ Conv2 高亮热图已保存: Conv2_heatmap_highlighted.png
 - 图片尺寸: 12×10 inches
 - 特征图: 64 channels × 4979 positions
 - LTR-RT: [2050, 2929] (879 pos)

✅ Conv3 高亮热图已保存: Conv3_heatmap_highlighted.png
 - 图片尺寸: 12×10 inches
 - 特征图: 128 channels × 297 positions
 - LTR-RT: [137, 160] (23 pos)

✅ 完成！所有带高亮的热图已生成

图例说明:
  🔴 红色实线框: All-LTR-RT 区间 (卷积核完全在LTR-RT内)

