In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleModel(nn.Module):
    """Two-layer MLP: Linear(10, 50) -> ReLU -> Linear(50, 10).

    The submodule names ``fc1`` / ``fc2`` are matched by name in a later FX
    graph-rewriting cell (``node.target == 'fc1'``), so keep them stable.
    """

    def __init__(self):
        super().__init__()
        # Registration order fc1 -> fc2 fixes parameter order and init RNG order.
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


In [None]:
from torch.fx import symbolic_trace

# Create a model instance
model = SimpleModel()

# Trace the model with PyTorch FX (symbolic tracing returns a GraphModule)
traced = symbolic_trace(model)

# Print the captured computation graph
print(traced.graph)

In [None]:
# Show the Python source that FX generated for the traced forward()
print(traced.code)

In [None]:
# Insert `torch.mul(fc1_out, 2)` right after the `fc1` call and route every
# downstream consumer of fc1's output through the new node.
# NOTE: this mutates `traced` in place; re-running this cell without
# re-tracing stacks another `* 2` onto the graph.
for node in list(traced.graph.nodes):  # snapshot: the graph is modified inside the loop
    if node.op == 'call_module' and node.target == 'fc1':
        with traced.graph.inserting_after(node):
            new_node = traced.graph.call_function(torch.mul, args=(node, 2))
        # replace_all_uses_with also rewrites new_node's own argument
        # (node -> new_node, a self-loop), so point it back at `node`.
        node.replace_all_uses_with(new_node)
        new_node.replace_input_with(new_node, node)
traced.graph.lint()  # validate the rewritten graph before regenerating forward()
traced.recompile()
print(traced.graph)

In [None]:
# Create a fresh model instance
model = SimpleModel()

# Convert the model to a ScriptModule with torch.jit.trace; tracing records the
# ops executed for this example input (last dim 10 matches fc1's in_features)
example_input = torch.randn(1, 10)
traced_script_module = torch.jit.trace(model, example_input)



In [None]:
# Display the module's repr (last expression of the cell)
model

In [None]:

# Print the traced ScriptModule
print(traced_script_module)


In [None]:

# Export as TorchScript (relative path: written to the current working directory)
traced_script_module.save("simple_model_traced.pt")


In [None]:

# Convert the model to a ScriptModule with torch.jit.script (compiles forward()
# from source rather than recording one example run)
scripted_script_module = torch.jit.script(model)

# Print the scripted ScriptModule
print(scripted_script_module)



In [None]:
# Export as TorchScript (relative path: written to the current working directory)
scripted_script_module.save("simple_model_scripted.pt")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Model whose activation is applied through nn.functional (no ReLU submodule)
class FunctionalModel(nn.Module):
    """Linear(10, 50) -> F.relu -> Linear(50, 10); ReLU is called functionally, so print() omits it."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        activated = F.relu(self.fc1(x))
        return self.fc2(activated)

# Model whose activation is a registered nn.ReLU submodule
class ModuleModel(nn.Module):
    """Linear(10, 50) -> nn.ReLU -> Linear(50, 10); ReLU is a child module, so print() lists it."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 50)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.relu1(hidden)
        return self.fc2(hidden)

# Instantiate both variants
functional_model = FunctionalModel()
module_model = ModuleModel()

# Print each header followed by the model repr; only ModuleModel lists a ReLU child
print("Functional Model:", functional_model, sep="\n")
print("\nModule Model:", module_model, sep="\n")


In [None]:
import transformers
import torch

import os

# Local Llama-3-8B checkpoint directory. Set LLAMA3_PATH to point at the
# checkpoint on other machines; the default is the original location.
model_path = os.environ.get("LLAMA3_PATH", "/dataset/crosspipe/llama3-8b")

In [None]:
# Text-generation pipeline over the local checkpoint: weights loaded in bfloat16,
# device placement delegated to accelerate via device_map="auto".
pipeline = transformers.pipeline(
    task="text-generation",
    model=model_path,
    model_kwargs=dict(torch_dtype=torch.bfloat16),
    device_map="auto",
)

In [None]:
# Generate a completion for a sample prompt.
# NOTE(review): `max_length` caps prompt + generated tokens together; use
# `max_new_tokens` if the intent is to bound only the continuation.
pipeline("Hey how are you doing today?",max_length=50)

In [None]:
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch.fx as fx
import time

# Load the Llama-2-7B-chat model and tokenizer. Set LLAMA2_CHAT_PATH to point at
# the checkpoint on other machines; the default is the original location.
import os

model_name = os.environ.get(
    "LLAMA2_CHAT_PATH", "/dataset/crosspipe/llama-2-chat/Llama-2-7b-chat-hf"
)
model = LlamaForCausalLM.from_pretrained(model_name)
tokenizer = LlamaTokenizer.from_pretrained(model_name)




In [None]:
# Define the input: tokenize the prompt into PyTorch tensors.
# NOTE(review): `inputs` is rebound to a random tensor in a later cell.
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt")



In [None]:
# Display the loaded model's module tree
model

In [None]:
# Convert the Hugging Face model into an FX GraphModule.
# NOTE(review): plain torch.fx.symbolic_trace generally cannot trace HF models.
# It turns *every* forward() argument (including optional ones that default to
# None) into a Proxy, so the model's `is None` checks take the wrong branch.
# transformers ships an HF-aware tracer that treats only `input_names` as graph
# inputs. These names match the keys the tokenizer cell produces.
from torch.fx import symbolic_trace
from transformers.utils.fx import symbolic_trace as hf_symbolic_trace

traced_model = hf_symbolic_trace(model, input_names=["input_ids", "attention_mask"])

In [None]:
# Display the traced GraphModule
traced_model

In [64]:
import torch 
from torch import nn
from torch import fx
from torch.fx import symbolic_trace


class MyModel(nn.Module):
    """Add a learnable 4-element bias to the input, then clamp the result to [0, 1]."""

    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.tensor([1.0, 2.0, 3.0, 4.0]))

    def forward(self, x):
        shifted = x + self.param
        return shifted.clamp(min=0.0, max=1.0)

In [65]:
# NOTE: rebinds `model` (previously the Llama checkpoint) to the small MyModel
model=MyModel()

In [66]:
# Symbolically trace MyModel and print its graph
# (placeholder -> get_attr param -> add -> clamp -> output)
symbolic_traced=symbolic_trace(model)
print(symbolic_traced.graph)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%add,), kwargs = {min: 0.0, max: 1.0})
    return clamp


In [51]:
# Show the generated forward() source, then the graph as a table.
# print_tabular() prints directly (it requires the `tabulate` package) and returns None.
print(symbolic_traced.code)
symbolic_traced.graph.print_tabular()




def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    clamp = add.clamp(min = 0.0, max = 1.0);  add = None
    return clamp
    
opcode         name    target                   args        kwargs
-------------  ------  -----------------------  ----------  ------------------------
placeholder    x       x                        ()          {}
get_attr       param   param                    ()          {}
call_function  add     <built-in function add>  (x, param)  {}
call_method    clamp   clamp                    (add,)      {'min': 0.0, 'max': 1.0}
output         output  output                   (clamp,)    {}


In [58]:
def transform(m):
    """Return a new GraphModule equal to ``m`` with every ``.clamp(...)`` call replaced by ``.sigmoid()``.

    ``m`` is traced with a fresh ``fx.Tracer``; ``m`` itself is not modified.

    Args:
        m: the ``nn.Module`` to trace and rewrite.

    Returns:
        An ``fx.GraphModule`` whose forward applies ``sigmoid`` where ``m`` applied ``clamp``.
    """
    graph = fx.Tracer().trace(m)
    # Iterate over a snapshot because nodes are inserted and erased inside the loop.
    for node in list(graph.nodes):
        if node.op == "call_method" and node.target == "clamp":
            # Build a fresh node instead of renaming in place. graph.call_method takes
            # a unique name from the graph's namespace, whereas assigning node.name
            # directly can produce duplicate names when several clamps are rewritten.
            # Only the tensor operand (args[0]) is kept. Clamp's bounds may be
            # positional and must not be forwarded to sigmoid().
            with graph.inserting_before(node):
                sigmoid_node = graph.call_method("sigmoid", args=(node.args[0],))
            node.replace_all_uses_with(sigmoid_node)
            graph.erase_node(node)
    graph.lint()
    return fx.GraphModule(m, graph)

# Rewrite `model` (the MyModel instance from the cell above) and inspect the
# result three ways: graph IR, regenerated Python source, and tabular view.
trans_model=transform(model)
print(trans_model.graph)
print(trans_model.code)
trans_model.graph.print_tabular()

clamp
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %sigmoid : [num_users=1] = call_method[target=sigmoid](args = (%add,), kwargs = {})
    return sigmoid



def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    sigmoid = add.sigmoid();  add = None
    return sigmoid
    
opcode         name     target                   args        kwargs
-------------  -------  -----------------------  ----------  --------
placeholder    x        x                        ()          {}
get_attr       param    param                    ()          {}
call_function  add      <built-in function add>  (x, param)  {}
call_method    sigmoid  sigmoid                  (add,)      {}
output         output   output                   (sigmoid,)  {}


In [59]:
class MyModel1(nn.Module):
    """Reference model: the same learnable 4-element bias as MyModel, followed by sigmoid instead of clamp."""

    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.tensor([1.0, 2.0, 3.0, 4.0]))

    def forward(self, x):
        biased = x + self.param
        return biased.sigmoid()
    
# Hand-written sigmoid reference to compare against the FX-rewritten model
test=MyModel1()


# NOTE: rebinds `inputs` (previously the tokenizer output) to a random (1, 4) tensor
inputs = torch.randn(1,4)
# The clamp -> sigmoid rewrite must match the reference numerically
torch.testing.assert_close(test(inputs),trans_model(inputs))


In [62]:
# Trace the reference model; its table should match trans_model's (sigmoid as a call_method)
symbolic_traced1=symbolic_trace(test)
symbolic_traced1.graph.print_tabular()


opcode         name     target                   args        kwargs
-------------  -------  -----------------------  ----------  --------
placeholder    x        x                        ()          {}
get_attr       param    param                    ()          {}
call_function  add      <built-in function add>  (x, param)  {}
call_method    sigmoid  sigmoid                  (add,)      {}
output         output   output                   (sigmoid,)  {}


In [72]:
from torch.fx import replace_pattern
def pattern(x):
    """Subgraph for replace_pattern to search for: a clamp to [0, 1] on the input."""
    # replace_pattern traces this function, so the call form is the match criterion.
    # It traces to a call_method `clamp` with min/max kwargs. A different spelling
    # (e.g. torch.clamp(...) or positional bounds) traces differently and presumably
    # would not match.
    return x.clamp(min=0.0,max=1.0)

def replacement(x):
    """Subgraph substituted for each match of `pattern`: a call_method `sigmoid` on the matched input."""
    return x.sigmoid()
# Rewrite every clamp(min=0.0, max=1.0) in symbolic_traced into sigmoid().
# replace_pattern edits the GraphModule in place and recompiles it.
replace_pattern(symbolic_traced,pattern,replacement)
# print_tabular() prints the table itself and returns None; wrapping it in
# print() only emitted a stray "None".
symbolic_traced.graph.print_tabular()

opcode         name     target                   args        kwargs
-------------  -------  -----------------------  ----------  --------
placeholder    x        x                        ()          {}
get_attr       param    param                    ()          {}
call_function  add      <built-in function add>  (x, param)  {}
call_method    sigmoid  sigmoid                  (add,)      {}
output         output   output                   (sigmoid,)  {}
None


In [73]:
# After replace_pattern, symbolic_traced must compute the same sigmoid as the reference model
torch.testing.assert_close(test(inputs),symbolic_traced(inputs))

In [None]:
from torch.nn.utils.fusion import fuse_conv_bn_eval
from torch.fx.node import Argument, Target
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
import copy

In [None]:
def _parent_name(target : str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name


