
Commit

[NFC] polish code style (#4799)
Camille7777 committed Sep 27, 2023
1 parent da3310f commit b9e1fef
Showing 2 changed files with 19 additions and 24 deletions.
34 changes: 15 additions & 19 deletions colossalai/inference/tensor_parallel/policies/chatglm2.py
@@ -1,29 +1,26 @@
from functools import partial

import torch

from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
    ChatGLMForConditionalGeneration,
    ChatGLMModel,
    GLMBlock,
    GLMTransformer,
    SelfAttention,
)

# import colossalai
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy

from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary

try:
    from colossalai.kernel.triton.rms_norm import rmsnorm_forward
    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
    HAS_TRITON_RMSNORM = False


class ChatGLM2InferPolicy(ChatGLMModelPolicy):
-
    def __init__(self) -> None:
        super().__init__()

@@ -32,24 +29,24 @@ def module_policy(self):
        self.shard_config._infer()

        model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
-        method_replacement = {'forward': model_infer_forward}
+        method_replacement = {"forward": model_infer_forward}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)

        encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
-        method_replacement = {'forward': encoder_infer_forward}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=GLMTransformer)
+        method_replacement = {"forward": encoder_infer_forward}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=GLMTransformer
+        )

        encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
-        method_replacement = {'forward': encoder_layer_infer_forward}
+        method_replacement = {"forward": encoder_layer_infer_forward}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)

        attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
-        method_replacement = {'forward': attn_infer_forward}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=SelfAttention)
+        method_replacement = {"forward": attn_infer_forward}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=SelfAttention
+        )

        # for rmsnorm and others, we need to check the shape
        return policy
@@ -60,17 +57,16 @@ def postprocess(self):


class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
-
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
-        method_replacement = {'forward': partial(model_infer_forward)}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=ChatGLMForConditionalGeneration)
+        method_replacement = {"forward": partial(model_infer_forward)}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration
+        )
        return policy

    def postprocess(self):
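For orientation (not part of the commit): the Python file above follows one pattern throughout, a policy class that maps target module classes to replacement forward functions via append_or_create_method_replacement. The sketch below is a minimal, self-contained illustration of that idea only; it is not ColossalAI code, and every name in it is hypothetical.

# Toy sketch of the "method replacement" pattern (hypothetical names, not the
# ColossalAI API): a policy maps a target class to the attribute it should
# override, and applying the policy patches each class in place.
class TargetBlock:
    def forward(self, x):
        return x + 1  # original, unoptimized path


def infer_forward(self, x):
    return x + 100  # stand-in for an optimized inference forward


def apply_policy(policy):
    # policy: {target_class: {attribute_name: replacement_function}}
    for target_cls, replacements in policy.items():
        for name, fn in replacements.items():
            setattr(target_cls, name, fn)


apply_policy({TargetBlock: {"forward": infer_forward}})
assert TargetBlock().forward(1) == 101  # the replacement forward is now used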
9 changes: 4 additions & 5 deletions colossalai/kernel/cuda_native/csrc/gptq/tuning.h
@@ -3,11 +3,10 @@
#ifndef _tuning_h
#define _tuning_h

-struct ExLlamaTuning
-{
-    int matmul_recons_thd;
-    bool matmul_fused_remap;
-    bool matmul_no_half2;
+struct ExLlamaTuning {
+  int matmul_recons_thd;
+  bool matmul_fused_remap;
+  bool matmul_no_half2;
};

#endif
