In [4]:
import torch
import torch.nn as nn

def post_processing_module(logits, input_ids, device, sentence_ps='None', window_size=3):
    pad_token_id = 0
    seq_split_id_list = [42000, 42001, 42002, 42003]
    # 여러 패딩 토큰 ID를 고려한 불리언 마스크 생성
    mask = torch.isin(input_ids, torch.tensor(seq_split_id_list, device=device))
    print(mask.shape, mask)
    # True 값의 인덱스를 리스트 형태로 추출
    true_indices = [torch.where(mask[b])[0].tolist() for b in range(batch_size)]
    print(true_indices)
    # 유효한 시퀀스 길이를 구함
    sequence_lengths = (torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1) % input_ids.shape[-1]


    def segment_average_torch(input_tensor, true_indices, sequence_lengths):
        batch_size, seq_len, class_num = input_tensor.shape
        segment_avg_list = []

        for b in range(batch_size):
            start_idx = 0
            segments = true_indices[b]
            
            batch_segment_avg = []
            for end_idx in segments:

                if start_idx < end_idx:  # Ensure there is a valid segment
                    segment = input_tensor[b, start_idx:end_idx, :]  # 문장 단위
                    
                    # Calculate segment average for the segment
                    avg_value = segment.mean(dim=0)

                    batch_segment_avg.append(avg_value)
                
                start_idx = end_idx  # Move to the next segment
            
                if end_idx < sequence_lengths[b]:
                    segment = input_tensor[b, start_idx:sequence_lengths[b]+1, :]
                    avg_value = segment.mean(dim=0)
                    batch_segment_avg.append(avg_value)

            segment_avg_list.append(torch.stack(batch_segment_avg))
        
        return segment_avg_list

    def moving_average_segments(segment_avg_logits, window_size):
        moving_avg_segments = []

        for batch in segment_avg_logits:
            moving_avg_batch = []
            num_segments = batch.shape[0]
            
            for i in range(num_segments):
                start = max(0, i - window_size + 1)
                end = i + 1
                window_avg = batch[start:end].mean(dim=0)
                moving_avg_batch.append(window_avg)
            
            moving_avg_segments.append(torch.stack(moving_avg_batch))
        
        return moving_avg_segments

    if sentence_ps == 'None':
        # 세그먼트 평균 계산
        seq_logits = segment_average_torch(logits, true_indices, sequence_lengths)

    # 세그먼트 이동 평균 계산
    elif sentence_ps == 'moving_average':
        segment_avg_logits = segment_average_torch(logits, true_indices, sequence_lengths)
        seq_logits = moving_average_segments(segment_avg_logits, window_size)

    logit = []
    for batch in seq_logits:
        v, _ = torch.max(batch, dim=0)
        logit.append(v)
    logit = torch.stack(logit, dim=0)

    return logits, logit, seq_logits



In [11]:

batch_size = 2
seq_len = 12
dim = 4
class_num = 3
pad_token_id = 0
seq_split_id_list = [42000, 42001, 42002, 42003]  # [SPK0]: 42000 [SPK1]: 42001 [SPK2]: 42002 [SPK3]: 42003
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 임의의 hidden_states, input_ids 생성
hidden_states = torch.randn(batch_size, seq_len, dim).to(device)
input_ids = torch.tensor([
    [2, 42000, 1234, 12312, 42002, 5234, 132, 13222, pad_token_id, pad_token_id, pad_token_id, 3],
    [2, 42000, 1234, 12312, 42002, 5234, 132, 13222, 42002, 4232, 1323, 3],
], device=device)

# logits 값을 1씩 증가하는 형태로 초기화
logits = torch.randn(batch_size, seq_len, class_num).to(device)
# logits = torch.arange(1, batch_size * seq_len * class_num + 1, dtype=torch.float).view(batch_size, seq_len, class_num).to(device)

logits, logit, seq_logits = post_processing_module(logits=logits, input_ids= input_ids, device= device, sentence_ps='None') # mean_np_max

# logits, logit, seq_logits = post_processing_module(logits= logits, input_ids= input_ids, device= device, sentence_ps='moving_average', window_size=3) # mean_moving-average(3)_max
# logits, logit, seq_logits = post_processing_module(logits= logits, input_ids= input_ids, device= device, sentence_ps='moving_average', window_size=5) # mean_moving-average(5)_max

torch.Size([2, 12]) tensor([[False,  True, False, False,  True, False, False, False, False, False,
         False, False],
        [False,  True, False, False,  True, False, False, False,  True, False,
         False, False]], device='cuda:0')
[[1, 4], [1, 4, 8]]


In [12]:
logits

tensor([[[-0.4481, -0.1480, -0.9030],
         [ 3.6518, -1.3573, -0.7390],
         [-1.7076, -0.1860,  0.9999],
         [ 1.1357,  1.1619, -0.7755],
         [-0.8124,  0.7639, -0.7912],
         [ 0.3594, -1.0360, -0.7358],
         [ 0.8928, -1.3632,  0.7288],
         [-0.1187, -0.6048,  1.0007],
         [ 0.0888, -2.1452, -0.7284],
         [ 0.0259, -0.8101, -0.0561],
         [-0.3895,  0.0760,  0.0503],
         [-0.1721,  0.9788, -0.9429]],

        [[-1.4614, -0.2747, -1.3924],
         [-0.3280, -2.6116, -0.8565],
         [-1.3301, -2.1722,  0.8784],
         [ 1.7696, -2.2237, -0.8042],
         [ 0.5479,  0.2335,  0.6776],
         [ 0.4886, -1.7863, -0.7142],
         [ 0.8375,  0.1173, -0.8532],
         [ 0.7245, -1.0953, -0.0191],
         [-0.4028, -0.3469, -0.1195],
         [-0.0101, -1.2463,  2.2080],
         [ 0.8413, -0.6801, -0.9111],
         [ 1.3621, -0.2059,  0.8476]]], device='cuda:0')

In [13]:
logit

tensor([[ 1.0266, -0.1271,  0.0506],
        [ 0.6496, -0.2747,  0.5062]], device='cuda:0')

In [17]:
seq_logits

[tensor([[-0.4481, -0.1480, -0.9030],
         [ 0.4859, -0.3745, -0.0446],
         [ 1.0266, -0.1271, -0.1716],
         [ 0.0803, -0.5600,  0.0506]], device='cuda:0'),
 tensor([[-1.4614, -0.2747, -1.3924],
         [ 0.4091, -1.0925,  0.0303],
         [ 0.0372, -2.3358, -0.2608],
         [ 0.5486, -0.6263,  0.1395],
         [ 0.6496, -0.6327, -0.2272],
         [ 0.4476, -0.6198,  0.5062]], device='cuda:0')]

In [19]:
torch.cat(seq_logits, dim=0).shape

torch.Size([10, 3])