In [6]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("codegen25",trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("codegen25")
model.to(device)

prompt = """
write a function to calculate the factorial of a number
"""
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# 生成代码（调整参数优化输出）
generated_ids = model.generate(
    **inputs,
    max_length=4196,           # 调整生成文本的最大长度
    temperature=0.7,          # 控制随机性（0-1，值越小越确定）
    top_p=0.9,                # 核采样参数
    do_sample=True,           # 启用随机采样
    pad_token_id=tokenizer.eos_token_id
)

# 解码生成的代码
generated_code = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_code)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]


write a function to calculate the factorial of a number
def factorial(n):
    if n == 0:
        return 1
    else:
        return n * factorial(n-1)
#


In [None]:
import ast

class CppCodeGenerator(ast.NodeVisitor):
    def __init__(self):
        self.cpp = []
        self.indent_level = 0
        self.current_function = None
        self.current_class = None
        self.vars = {}

    def add_line(self, line):
        indent = '    ' * self.indent_level
        self.cpp.append(indent + line)

    def get_code(self):
        return '\n'.join(self.cpp)

    def visit_Module(self, node):
        for body_node in node.body:
            self.visit(body_node)

    def visit_FunctionDef(self, node):
        return_type = 'void'  # Default return type
        params = []
        for arg in node.args.args:
            param_type = self.infer_type_from_annotation(arg.annotation) if arg.annotation else 'auto'
            params.append(f"{param_type} {arg.arg}")
        param_str = ', '.join(params)
        docstring = ast.get_docstring(node)
        if docstring:
            self.add_line(f'// {docstring}')
        self.add_line(f'{return_type} {node.name}({param_str}) {{')
        self.indent_level += 1
        for stmt in node.body:
            self.visit(stmt)
        self.indent_level -= 1
        self.add_line('}')

    def visit_ClassDef(self, node):
        self.current_class = node.name
        self.add_line(f'class {node.name} {{')
        self.add_line('public:')
        self.indent_level += 1
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                self.visit(stmt)
        self.indent_level -= 1
        self.add_line('};')
        self.current_class = None

    def visit_For(self, node):
        if isinstance(node.iter, ast.Call) and isinstance(node.iter.func, ast.Name) and node.iter.func.id == 'range':
            args = node.iter.args
            loop_var = node.target.id
            if len(args) == 1:
                start = 0
                end = self.visit(args[0])
                self.add_line(f'for (int {loop_var} = {start}; {loop_var} < {end}; ++{loop_var}) {{')
            elif len(args) == 2:
                start = self.visit(args[0])
                end = self.visit(args[1])
                self.add_line(f'for (int {loop_var} = {start}; {loop_var} < {end}; ++{loop_var}) {{')
            elif len(args) == 3:
                start = self.visit(args[0])
                end = self.visit(args[1])
                step = self.visit(args[2])
                self.add_line(f'for (int {loop_var} = {start}; {loop_var} < {end}; {loop_var} += {step}) {{')
            self.indent_level += 1
            for stmt in node.body:
                self.visit(stmt)
            self.indent_level -= 1
            self.add_line('}')

    def visit_While(self, node):
        condition = self.visit(node.test)
        self.add_line(f'while ({condition}) {{')
        self.indent_level += 1
        for stmt in node.body:
            self.visit(stmt)
        self.indent_level -= 1
        self.add_line('}')

    def visit_If(self, node):
        condition = self.visit(node.test)
        self.add_line(f'if ({condition}) {{')
        self.indent_level += 1
        for stmt in node.body:
            self.visit(stmt)
        self.indent_level -= 1
        if node.orelse:
            self.add_line('} else {')
            self.indent_level += 1
            for stmt in node.orelse:
                self.visit(stmt)
            self.indent_level -= 1
            self.add_line('}')
        else:
            self.add_line('}')

    def visit_Expr(self, node):
        if isinstance(node.value, ast.Str):
            self.add_line(f'// {node.value.s}')
        else:
            self.generic_visit(node)

    def visit_Assign(self, node):
        target = node.targets[0]
        var_name = target.id
        value = self.visit(node.value)
        var_type = self.infer_type(node.value)
        self.vars[var_name] = var_type
        self.add_line(f'{var_type} {var_name} = {value};')

    def visit_Return(self, node):
        value = self.visit(node.value) if node.value else ''
        self.add_line(f'return {value};')

    def visit_Call(self, node):
        func_name = self.visit(node.func)
        args = [self.visit(arg) for arg in node.args]
        if func_name == 'print':
            parts = ' << '.join(args)
            return f'std::cout << {parts} << std::endl'
        return f'{func_name}({", ".join(args)})'

    def visit_Name(self, node):
        return node.id

    def visit_Constant(self, node):
        if isinstance(node.value, str):
            return f'"{node.value}"'
        return str(node.value)

    def visit_BinOp(self, node):
        left = self.visit(node.left)
        op = self.visit(node.op)
        right = self.visit(node.right)
        return f'({left} {op} {right})'

    def visit_Add(self, node):
        return '+'

    def infer_type(self, node):
        if isinstance(node, ast.Constant):
            if isinstance(node.value, int):
                return 'int'
            elif isinstance(node.value, float):
                return 'double'
            elif isinstance(node.value, str):
                return 'std::string'
            elif isinstance(node.value, bool):
                return 'bool'
        return 'auto'

    def infer_type_from_annotation(self, annotation):
        if isinstance(annotation, ast.Name):
            return annotation.id
        return 'auto'

def python_to_cpp(python_code):
    try:
        tree = ast.parse(python_code)
        generator = CppCodeGenerator()
        generator.visit(tree)
        return generator.get_code()
    except Exception as e:
        return f"// Error converting Python code to C++: {str(e)}"

python_code="""
def factorial(n):
    if n == 0:
        return 1
    else:
        return n * factorial(n-1)
"""

cpp_code = python_to_cpp(python_code)
print(cpp_code)

// Error converting Python code to C++: 'Attribute' object has no attribute 'id'
