In [8]:
import torch
import glob
import os
from collections import OrderedDict, defaultdict
import math
import random
from tqdm import tqdm  # 引入 tqdm 库
import time  # 引入 time 模块
import argparse  # 引入 argparse 模块
import sys
import numpy as np
import torch.optim as optim
import torch.nn as nn
from io import BytesIO
from torch.utils.data import DataLoader, Subset, random_split
import pandas as pd

## 对原始数据集进行处理
在数据集类 `QwenVisionDataset` 中增加了波束热力图信息：每个滑动窗口样本除 RGB 图像路径外，还会返回对应的 mmWave 波束热力图（`unit1_mmwave_heatmap` 列）的图像路径。

In [9]:
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from PIL import Image

class QwenVisionDataset(Dataset):
    """Sliding-window dataset built from scenario CSV index files.

    Each CSV row describes one time step and holds *relative* paths (resolved
    against the CSV's own directory) to an RGB frame, two location files, an
    mmWave power file and an mmWave heatmap image. Rows are grouped by the
    ``seq_index`` column and every run of ``input_length`` consecutive rows
    inside one sequence becomes one sample.

    Args:
        data_csv_paths: Indexable collection of CSV file paths.
        modal: Stored on the instance as ``self.modal``; it is not consulted
            anywhere in this class.
        input_length: Number of consecutive time steps per window.
        output_length: Number of trailing time steps of the window whose
            mmWave vectors are returned as the target.

    Raises:
        ValueError: If ``input_length < 1`` or ``output_length`` is not in
            ``[1, input_length]``.

    NOTE(review): the target is the *last* ``output_length`` steps of the same
    window that is returned as ``'mmwave'`` input, so the target values are
    also present in the input. Confirm that this overlap is intended.
    """

    def __init__(self, data_csv_paths, modal='mmwave_gps', input_length=8, output_length=3):
        if input_length < 1:
            raise ValueError(f"input_length must be >= 1, got {input_length}")
        # A larger output_length would make the target index negative, and
        # .iloc would then silently wrap to rows outside of the window.
        if not 1 <= output_length <= input_length:
            raise ValueError(
                f"output_length must be in [1, input_length={input_length}], "
                f"got {output_length}")

        self.data_csv_paths = data_csv_paths
        self.modal = modal
        self.input_length = input_length
        self.output_length = output_length

        # Logical feature name -> CSV column name.
        self.features_column = {
            'rgbs': 'unit1_rgb',
            'u1_loc': 'unit1_loc',
            'u2_loc': 'unit2_loc',
            'mmwave': 'unit1_pwr_60ghz',
            'heatmap': 'unit1_mmwave_heatmap'  # heatmap image column
        }

        # One entry per sample: (csv index, sequence id, window start row).
        self.window_samples = []
        # Per-sequence frames, cached here so that __getitem__ does not have
        # to re-read and re-filter the whole CSV for every single sample.
        self._sequences = {}
        for seq_idx, data_csv_path in enumerate(self.data_csv_paths):
            data_csv = pd.read_csv(data_csv_path)
            for seq_id in data_csv['seq_index'].unique():
                seq_data = data_csv[data_csv['seq_index'] == seq_id]
                n_windows = len(seq_data) - self.input_length + 1
                if n_windows <= 0:
                    # Sequence is shorter than one window: contributes nothing.
                    continue
                self._sequences[(seq_idx, seq_id)] = seq_data
                self.window_samples.extend(
                    (seq_idx, seq_id, start_idx) for start_idx in range(n_windows))

    def __len__(self):
        """Return the total number of sliding windows over all sequences."""
        return len(self.window_samples)

    @staticmethod
    def _read_floats(path):
        """Read a text file of whitespace-separated numbers as a list of floats."""
        with open(path, 'r') as f:
            return [float(token) for token in f.read().split()]

    def __getitem__(self, idx):
        """Load one window.

        Returns:
            dict with keys
              - ``'video_paths'``: list of ``input_length`` RGB frame paths,
              - ``'heatmap_paths'``: list of ``input_length`` heatmap paths,
              - ``'gps'``: float32 tensor ``(input_length, 2)`` holding the
                per-step difference (unit2 - unit1) of the two values stored
                in each location file,
              - ``'mmwave'``: float32 tensor with one row per time step,
              - ``'target_mmwave'``: the last ``output_length`` rows of
                ``'mmwave'`` (an independent copy).
        """
        seq_idx, seq_id, start_idx = self.window_samples[idx]
        base_path = os.path.dirname(self.data_csv_paths[seq_idx])
        seq_data = self._sequences[(seq_idx, seq_id)]
        window = seq_data.iloc[start_idx:start_idx + self.input_length]

        def resolve(feature):
            # CSV paths are relative to the directory that contains the CSV.
            column = self.features_column[feature]
            return [os.path.join(base_path, rel) for rel in window[column].tolist()]

        # GPS: each location file must contain exactly two numbers.
        gps_steps = []
        for u1_loc, u2_loc in zip(resolve('u1_loc'), resolve('u2_loc')):
            lat1, lon1 = self._read_floats(u1_loc)
            lat2, lon2 = self._read_floats(u2_loc)
            gps_steps.append(
                torch.tensor([lat2 - lat1, lon2 - lon1], dtype=torch.float32))
        gps = torch.stack(gps_steps)

        # mmWave: one vector per time step.
        mmwave = torch.stack([
            torch.tensor(self._read_floats(mmwave_path), dtype=torch.float32)
            for mmwave_path in resolve('mmwave')
        ])

        # Target: trailing output_length steps of the window. They are already
        # loaded above, so slice instead of reading the same files again;
        # clone so that the target does not share storage with the input.
        target = mmwave[self.input_length - self.output_length:].clone()

        return {
            'video_paths': resolve('rgbs'),
            'heatmap_paths': resolve('heatmap'),
            'gps': gps,
            'mmwave': mmwave,
            'target_mmwave': target
        }

def qwen_collate_fn(batch):
    """Merge a list of sample dicts into a single batch dict.

    Path entries ('video_paths', 'heatmap_paths') are gathered into a list
    with one element per sample. Tensor entries ('gps', 'mmwave',
    'target_mmwave') are combined with ``pad_sequence(batch_first=True)``,
    which zero-pads along the first (time) dimension up to the longest sample
    and puts the batch dimension first.
    """
    path_keys = ('video_paths', 'heatmap_paths')
    tensor_keys = ('gps', 'mmwave', 'target_mmwave')

    collated = {}
    for key in path_keys:
        collated[key] = [sample[key] for sample in batch]
    for key in tensor_keys:
        per_sample = [sample[key] for sample in batch]
        collated[key] = pad_sequence(per_sample, batch_first=True)
    return collated

In [10]:
# Dataset location: one directory per scenario.
dataset_start_idx = 1
dataset_end_idx = 9  # exclusive upper bound, so scenario1 .. scenario8 are used
dataset_path = [
    f'/data2/wzj/Datasets/DeepSense/scenario{i}/'
    for i in range(dataset_start_idx, dataset_end_idx)
]

# Collect every CSV index file found directly inside each scenario directory.
# glob.glob returns matches in arbitrary order, so sort them per directory to
# keep the dataset's sample indexing reproducible across runs and machines.
data_csv_paths = []
for path in dataset_path:
    data_csv_paths.extend(sorted(glob.glob(os.path.join(path, '*.csv'))))

print(f"Found {len(data_csv_paths)} CSV files for training.")

Found 8 CSV files for training.


In [None]:
# Sliding-window dataset over every discovered CSV: each sample is a window of
# 8 time steps, and the last 3 of them provide the mmWave target.
dataset = QwenVisionDataset(data_csv_paths, input_length=8, output_length=3)

# Batches of 32 windows, merged by the custom collate function.
dataloader = DataLoader(dataset, batch_size=32, collate_fn=qwen_collate_fn)

# During fine-tuning, the batches yielded by `dataloader` can be passed
# straight to the Qwen model's training step.

In [21]:
# Inspect one sample (window index 32) to sanity-check the fields the dataset returns.
dataset[32]

{'video_paths': ['/data2/wzj/Datasets/DeepSense/scenario1/./unit1/camera_data/image_BS1_488_00_42_18.jpg',
  '/data2/wzj/Datasets/DeepSense/scenario1/./unit1/camera_data/image_BS1_489_00_42_18.jpg',
  '/data2/wzj/Datasets/DeepSense/scenario1/./unit1/camera_data/image_BS1_490_00_42_18.jpg',
  '/data2/wzj/Datasets/DeepSense/scenario1/./unit1/camera_data/image_BS1_491_00_42_19.jpg',
  '/data2/wzj/Datasets/DeepSense/scenario1/./unit1/camera_data/image_BS1_492_00_42_19.jpg',
  '/data2/wzj/Datasets/DeepSense/scenario1/./unit1/camera_data/image_BS1_493_00_42_19.jpg',
  '/data2/wzj/Datasets/DeepSense/scenario1/./unit1/camera_data/image_BS1_494_00_42_19.jpg',
  '/data2/wzj/Datasets/DeepSense/scenario1/./unit1/camera_data/image_BS1_495_00_42_19.jpg'],
 'heatmap_paths': ['/data2/wzj/Datasets/DeepSense/scenario1/./unit1/mmWave_heatmap/mmWave_power_32.png',
  '/data2/wzj/Datasets/DeepSense/scenario1/./unit1/mmWave_heatmap/mmWave_power_33.png',
  '/data2/wzj/Datasets/DeepSense/scenario1/./unit1/mmWa