In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import pickle

In [2]:
import os

# Path to the pickled dict of per-video features. The default is the original
# machine's absolute path; set the ABB_TOTAL_DATA_PKL environment variable to
# point at a copy elsewhere without editing the notebook.
pkl_file = os.environ.get(
    'ABB_TOTAL_DATA_PKL',
    '/data/common/abb_project/processed_data/total_data.pkl',
)

In [3]:
# Unpickle the whole feature dict into memory in one go.
# NOTE: pickle.load can run arbitrary code embedded in the file, so pkl_file
# must only ever point at trusted data.
with open(pkl_file, mode='rb') as pkl_handle:
    total_data = pickle.load(pkl_handle)

In [11]:
# Shape of index 0 of the first video's entry (CustomDataset treats it as flow features).
total_data[next(iter(total_data))][0].shape

(1536, 1024)

In [12]:
# Shape of index 1 of the first video's entry (CustomDataset treats it as RGB features).
total_data[next(iter(total_data))][1].shape

(1536, 2048)

In [13]:
# Shape of index 2 of the first video's entry (CustomDataset treats it as labels).
total_data[next(iter(total_data))][2].shape

(1536, 13)

In [20]:
from torch.nn.utils.rnn import pad_sequence

def my_collate_fn(batch, padding_value=0):
    """Pad a batch of variable-length ``(flow, rgb, labels)`` samples.

    Args:
        batch: sequence of ``(flow, rgb, labels)`` tensor tuples, as returned
            by ``CustomDataset.__getitem__``. The first dimension of each
            tensor (assumed to be frames - TODO confirm) may differ between
            samples; trailing dimensions must match within each field.
        padding_value: fill value for positions past a sample's length.
            With the default 0, padded label rows look identical to genuine
            all-zero label rows; track sequence lengths separately if the
            loss must ignore padding.

    Returns:
        ``(flow, rgb, labels)`` tensors, each of shape
        ``[batch_size, max_len_in_batch, ...]`` (batch_first).
    """
    # Transpose the list of samples into one sequence per field.
    flow_features, rgb_features, labels = zip(*batch)

    # pad_sequence writes into a freshly allocated output tensor, so the
    # result never aliases the inputs; cloning each input first only made
    # a second, redundant copy of every feature tensor.
    def _pad(tensors):
        return pad_sequence(list(tensors), batch_first=True, padding_value=padding_value)

    return _pad(flow_features), _pad(rgb_features), _pad(labels)

In [21]:
class CustomDataset(Dataset):
    """Map-style dataset over an in-memory dict of per-video features.

    Each dict value is indexed as ``value[0]`` (flow features),
    ``value[1]`` (RGB features) and ``value[2]`` (labels); an item is the
    ``(flow, rgb, labels)`` tuple converted to tensors.

    Args:
        pkl_file: despite its name, the already-loaded dict (e.g. the object
            unpickled from the ``.pkl`` file), not a path to it.
    """

    def __init__(self, pkl_file):
        self.data_dict = pkl_file
        # Snapshot the keys once. Rebuilding list(keys()) on every
        # __getitem__ call made each lookup O(n) and each epoch O(n^2).
        # Keys added to the dict after construction are therefore not seen.
        self._video_names = list(self.data_dict.keys())

    def __len__(self):
        return len(self._video_names)

    @staticmethod
    def _to_tensor(value):
        """Return an independent tensor copy of ``value``.

        torch.tensor() on an existing tensor warns and recommends
        ``detach().clone()``; other inputs (arrays, lists) go through
        torch.tensor, which also copies.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value)

    def __getitem__(self, idx):
        entry = self.data_dict[self._video_names[idx]]
        flow_features = self._to_tensor(entry[0])
        rgb_features = self._to_tensor(entry[1])
        labels = self._to_tensor(entry[2])
        return flow_features, rgb_features, labels

# Wrap the in-memory dict in a shuffling DataLoader. The batch size is only an
# example value; the seeded generator makes the shuffle order reproducible.
BATCH_SIZE = 4
SEED = 0

custom_dataset = CustomDataset(total_data)
train_loader = DataLoader(
    dataset=custom_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=my_collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)


In [22]:
# Pull batches from the DataLoader and inspect the first one.
# my_collate_fn returns (flow, rgb, labels), so unpack in exactly that order;
# swapping the first two silently mislabels RGB and flow features.
for i, (flow_features, rgb_features, labels) in enumerate(train_loader):
    print(f"Batch {i+1}")
    print("RGB Features:", rgb_features.shape)  # expected [batch_size, num_frames, 2048]
    print("Flow Features:", flow_features.shape)  # expected [batch_size, num_frames, 1024]
    print("Labels:", labels.shape)  # expected [batch_size, num_frames, 13]

    # Feature widths observed on the first video when inspecting total_data.
    assert rgb_features.shape[-1] == 2048, "unexpected RGB feature width"
    assert flow_features.shape[-1] == 1024, "unexpected flow feature width"
    assert labels.shape[-1] == 13, "unexpected label width"

    # Only the first batch is inspected.
    break


Batch 1
RGB Features: torch.Size([4, 1600, 2048])
Flow Features: torch.Size([4, 1600, 1024])
Labels: torch.Size([4, 1600, 13])
