[NFC] polish code style #4799

Merged
merged 7 commits into from Sep 27, 2023
34 changes: 15 additions & 19 deletions colossalai/inference/tensor_parallel/policies/chatglm2.py
@@ -1,29 +1,26 @@
 from functools import partial
 
 import torch
 
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     ChatGLMForConditionalGeneration,
     ChatGLMModel,
     GLMBlock,
     GLMTransformer,
     SelfAttention,
 )
 
 # import colossalai
 from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
 
 from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary
 
 try:
     from colossalai.kernel.triton.rms_norm import rmsnorm_forward
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
     HAS_TRITON_RMSNORM = False
 
 
 class ChatGLM2InferPolicy(ChatGLMModelPolicy):
-
     def __init__(self) -> None:
         super().__init__()
@@ -32,24 +29,24 @@ def module_policy(self):
         self.shard_config._infer()
 
         model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
-        method_replacement = {'forward': model_infer_forward}
+        method_replacement = {"forward": model_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
 
         encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
-        method_replacement = {'forward': encoder_infer_forward}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=GLMTransformer)
+        method_replacement = {"forward": encoder_infer_forward}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=GLMTransformer
+        )
 
         encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
-        method_replacement = {'forward': encoder_layer_infer_forward}
+        method_replacement = {"forward": encoder_layer_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
 
         attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
-        method_replacement = {'forward': attn_infer_forward}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=SelfAttention)
+        method_replacement = {"forward": attn_infer_forward}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=SelfAttention
+        )
 
         # for rmsnorm and others, we need to check the shape
         return policy
@@ -60,17 +57,16 @@ def postprocess(self):
 
 
 class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
-
     def __init__(self) -> None:
         super().__init__()
 
     def module_policy(self):
         policy = super().module_policy()
         model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
-        method_replacement = {'forward': partial(model_infer_forward)}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=ChatGLMForConditionalGeneration)
+        method_replacement = {"forward": partial(model_infer_forward)}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration
+        )
         return policy
 
     def postprocess(self):
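
Note: every replacement in this file follows the same shape. A policy maps a target module class to a dict of {"method name": replacement function}, and the shardformer pass later rebinds those methods on matching submodules. Below is a rough sketch of that rebinding mechanism under assumed semantics; it is an illustration, not ColossalAI's actual append_or_create_method_replacement internals:

from types import MethodType

import torch.nn as nn

def apply_method_replacements(model: nn.Module, policy: dict) -> nn.Module:
    # Assumed policy shape: {TargetClass: {"forward": new_forward, ...}}
    for module in model.modules():
        replacements = policy.get(type(module))
        if not replacements:
            continue
        for name, fn in replacements.items():
            # Bind the plain function as an instance method so it
            # receives the wrapped module as `self`.
            setattr(module, name, MethodType(fn, module))
    return model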
9 changes: 4 additions & 5 deletions colossalai/kernel/cuda_native/csrc/gptq/tuning.h
@@ -3,11 +3,10 @@
 #ifndef _tuning_h
 #define _tuning_h
 
-struct ExLlamaTuning
-{
-    int matmul_recons_thd;
-    bool matmul_fused_remap;
-    bool matmul_no_half2;
+struct ExLlamaTuning {
+  int matmul_recons_thd;
+  bool matmul_fused_remap;
+  bool matmul_no_half2;
 };
 
 #endif