In [2]:
import torch
import torch.nn as nn
import time

# GPU 사용 가능 여부 확인
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 게이팅 네트워크 정의
class GatingNetwork(nn.Module):
    def __init__(self, input_dim, num_experts, add_noise=False, noise_std=1.0):
        super(GatingNetwork, self).__init__()
        self.layer = nn.Linear(input_dim, num_experts)
        self.add_noise = add_noise
        self.noise_std = noise_std

    def forward(self, x):
        gate_logits = self.layer(x)  # [batch_size, num_experts]
        if self.add_noise and self.training:
            noise = torch.randn_like(gate_logits) * self.noise_std
            gate_logits = gate_logits + noise
        gate_probs = torch.softmax(gate_logits, dim=-1)  # [batch_size, num_experts]
        return gate_probs  # [batch_size, num_experts]

# MoE with separate nn.Linear layers for each expert (True Sparse MoE)
class MoE_Sparse_True(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts, topk=2, add_noise=False, noise_std=1.0):
        super(MoE_Sparse_True, self).__init__()
        self.num_experts = num_experts
        self.topk = topk

        # 전문가들을 개별 nn.Linear 모듈로 정의
        self.experts = nn.ModuleList([nn.Linear(input_dim, output_dim).to(device) for _ in range(num_experts)])
        for expert in self.experts:
            nn.init.xavier_uniform_(expert.weight)
            nn.init.zeros_(expert.bias)

        self.gate = GatingNetwork(input_dim, num_experts, add_noise=add_noise, noise_std=noise_std)

    def forward(self, x):
        batch_size = x.size(0)

        # 1. 게이팅 네트워크를 통해 전문가 선택 확률 계산
        gate_probs = self.gate(x)  # [batch_size, num_experts]

        # 2. Topk 전문가 선택
        topk_probs, topk_indices = torch.topk(gate_probs, self.topk, dim=-1)  # [batch_size, topk]

        # 3. 초기 출력 텐서 생성
        output = torch.zeros(batch_size, self.experts[0].out_features, device=x.device)

        # 4. 배치와 Topk를 평탄화
        flat_indices = topk_indices.view(-1)  # [batch_size * topk]
        flat_probs = topk_probs.view(-1)      # [batch_size * topk]

        # 5. 고유한 전문가 식별
        unique_experts = torch.unique(flat_indices)

        # 6. 각 고유 전문가별로 출력을 계산하고 누적
        for expert_id in unique_experts:
            # 해당 전문가가 선택된 위치 마스크 생성
            mask = (flat_indices == expert_id)  # [batch_size * topk]
            if mask.sum() == 0:
                continue
            indices = mask.nonzero(as_tuple=False).view(-1)  # [count]

            # 배치 인덱스 계산 (0부터 batch_size-1까지)
            batch_indices = (indices // self.topk).long()  # [count]

            # 선택된 입력과 확률 추출
            x_selected = x[batch_indices]               # [count, input_dim]
            probs_selected = flat_probs[indices].unsqueeze(1)  # [count, 1]

            # 전문가 출력 계산
            expert_output = self.experts[expert_id](x_selected)  # [count, output_dim]

            # 출력 누적 (배치 인덱스를 사용하여 직접 인덱싱)
            output[batch_indices] += expert_output * probs_selected  # [batch_size, output_dim]

        return output  # [batch_size, output_dim]

# MoE with a single combined nn.Linear layer (for reference)
class MoE_Combined(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts, topk=2, add_noise=False, noise_std=1.0):
        super(MoE_Combined, self).__init__()
        self.num_experts = num_experts
        self.topk = topk

        # 모든 전문가를 하나의 큰 nn.Linear 레이어로 통합
        self.experts = nn.Linear(input_dim, output_dim * num_experts)
        nn.init.xavier_uniform_(self.experts.weight)
        nn.init.zeros_(self.experts.bias)

        self.gate = GatingNetwork(input_dim, num_experts, add_noise=add_noise, noise_std=noise_std)

    def forward(self, x):
        batch_size = x.size(0)

        # 1. 게이팅 네트워크를 통해 전문가 선택 확률 계산
        gate_probs = self.gate(x)  # [batch_size, num_experts]

        # 2. Topk 전문가 선택
        topk_probs, topk_indices = torch.topk(gate_probs, self.topk, dim=-1)  # [batch_size, topk]

        # 3. 전문가 출력 계산 (모든 전문가의 출력을 한 번에 계산)
        expert_outputs = self.experts(x)  # [batch_size, num_experts * output_dim]
        expert_outputs = expert_outputs.view(batch_size, self.num_experts, -1)  # [batch_size, num_experts, output_dim]

        # 4. 선택된 Topk 전문가의 출력 추출
        batch_indices = torch.arange(batch_size).unsqueeze(1).expand(-1, self.topk).to(x.device)  # [batch_size, topk]
        selected_expert_outputs = expert_outputs[batch_indices, topk_indices]  # [batch_size, topk, output_dim]

        # 5. 전문가 출력에 확률 가중치 적용
        topk_probs = topk_probs.unsqueeze(-1)  # [batch_size, topk, 1]
        weighted_expert_outputs = selected_expert_outputs * topk_probs  # [batch_size, topk, output_dim]

        # 6. 최종 출력 계산 (Topk 전문가의 가중합)
        output = weighted_expert_outputs.sum(dim=1)  # [batch_size, output_dim]

        return output  # [batch_size, output_dim]

# GPU 메모리 사용량 측정 함수
def get_gpu_memory():
    return torch.cuda.memory_allocated(device) / (1024 ** 2)  # MB 단위로 변환

# GPU 메모리 초기화 함수
def reset_gpu_memory():
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()

# 모델 비교 테스트 함수
def compare_models(model_a, model_b, input_dim, batch_size, num_iterations=100):
    # 모델을 GPU로 이동
    model_a.to(device)
    model_b.to(device)

    # 모델을 평가 모드로 설정
    model_a.eval()
    model_b.eval()

    # 더미 입력 데이터 생성
    x = torch.randn(batch_size, input_dim).to(device)

    # 워밍업
    with torch.no_grad():
        model_a(x)
        model_b(x)

    # GPU 메모리 초기화
    reset_gpu_memory()

    # 모델 A (Combined) 측정
    start_event_a = torch.cuda.Event(enable_timing=True)
    end_event_a = torch.cuda.Event(enable_timing=True)

    start_event_a.record()
    with torch.no_grad():
        for _ in range(num_iterations):
            output_a = model_a(x)
    end_event_a.record()
    torch.cuda.synchronize()
    elapsed_time_a = start_event_a.elapsed_time(end_event_a)  # ms 단위
    peak_mem_a = torch.cuda.max_memory_allocated(device) / (1024 ** 2)  # MB 단위

    # GPU 메모리 초기화
    reset_gpu_memory()

    # 모델 B (Sparse) 측정
    start_event_b = torch.cuda.Event(enable_timing=True)
    end_event_b = torch.cuda.Event(enable_timing=True)

    start_event_b.record()
    with torch.no_grad():
        for _ in range(num_iterations):
            output_b = model_b(x)
    end_event_b.record()
    torch.cuda.synchronize()
    elapsed_time_b = start_event_b.elapsed_time(end_event_b)  # ms 단위
    peak_mem_b = torch.cuda.max_memory_allocated(device) / (1024 ** 2)  # MB 단위

    # 결과 출력
    print(f"Model A (Combined nn.Linear):")
    print(f"  Total Time for {num_iterations} iterations: {elapsed_time_a:.2f} ms")
    print(f"  Peak GPU Memory Usage: {peak_mem_a:.2f} MB\n")

    print(f"Model B (Sparse MoE with Separate nn.Linear):")
    print(f"  Total Time for {num_iterations} iterations: {elapsed_time_b:.2f} ms")
    print(f"  Peak GPU Memory Usage: {peak_mem_b:.2f} MB\n")

# 메인 함수
def main():
    # 파라미터 설정
    input_dim = 128
    output_dim = 256
    num_experts = 10
    batch_size = 32
    topk = 2
    add_noise = False  # 공정한 비교를 위해 노이즈 비활성화
    noise_std = 1.0
    num_iterations = 100

    # 모델 인스턴스화
    model_combined = MoE_Combined(input_dim, output_dim, num_experts, topk=topk, add_noise=add_noise, noise_std=noise_std).to(device)
    model_sparse = MoE_Sparse_True(input_dim, output_dim, num_experts, topk=topk, add_noise=add_noise, noise_std=noise_std).to(device)

    # 모델 비교 실행
    compare_models(model_combined, model_sparse, input_dim, batch_size, num_iterations=num_iterations)

if __name__ == "__main__":
    main()


Using device: cuda
Model A (Combined nn.Linear):
  Total Time for 100 iterations: 26.97 ms
  Peak GPU Memory Usage: 12.02 MB

Model B (Sparse MoE with Separate nn.Linear):
  Total Time for 100 iterations: 257.81 ms
  Peak GPU Memory Usage: 11.79 MB

