In [1]:
import sys
from pathlib import Path
import ast
import astor
from importlib.machinery import PathFinder
env_dir = Path(sys.executable).parent.parent.as_posix()

# Patch the installed transformers CLIP text model so that a tensor can be
# passed in through a new ``input_embed`` keyword argument and used in place of
# the embedding layer's output. The patch rewrites
# transformers/models/clip/modeling_clip.py in place. It is idempotent: a
# ``forward`` that already has ``input_embed`` is left alone.

ARG_NAME = 'input_embed'        # name of the injected argument
ARG_DEFAULT_VAL = None          # its default value
# Annotation source text. Assumes modeling_clip.py already has ``Optional`` and
# ``torch`` in scope (annotations are evaluated at def time) -- TODO confirm.
ARG_ANNOTATION = 'Optional[torch.Tensor]'


def _find_method(class_node, method_name):
    """Return the first ``def <method_name>`` directly in *class_node*'s body, or None."""
    for stmt in class_node.body:
        if isinstance(stmt, ast.FunctionDef) and stmt.name == method_name:
            return stmt
    return None


def _is_self_attr(node, attr):
    """Return True if *node* is the attribute expression ``self.<attr>``."""
    return (
        isinstance(node, ast.Attribute)
        and node.attr == attr
        and isinstance(node.value, ast.Name)
        and node.value.id == 'self'
    )


def _add_optional_arg(func):
    """Append ``ARG_NAME: ARG_ANNOTATION = ARG_DEFAULT_VAL`` to *func*'s positional args.

    Returns:
        True if the argument was added; False if *func* already has a
        positional or keyword-only argument with that name (nothing changes).
    """
    existing = {a.arg for a in func.args.args + func.args.kwonlyargs}
    if ARG_NAME in existing:
        return False
    # Parse the annotation into a real expression node instead of smuggling
    # "Optional[...]" through an ast.Name id, which is not a valid identifier.
    annotation = ast.parse(ARG_ANNOTATION, mode='eval').body
    func.args.args.append(ast.arg(arg=ARG_NAME, annotation=annotation))
    # ``defaults`` pairs with the *last* len(defaults) positional args, so
    # appending one default keeps every existing arg/default pairing intact.
    func.args.defaults.append(ast.Constant(value=ARG_DEFAULT_VAL))
    return True


def _patch_text_model_forward(func):
    """Add ARG_NAME to CLIPTextModel.forward and pass it on to ``self.text_model(...)``.

    Returns:
        True if *func* was modified.
    """
    if not _add_optional_arg(func):
        return False
    for node in ast.walk(func):
        if (
            isinstance(node, ast.Call)
            and _is_self_attr(node.func, 'text_model')
            and all(kw.arg != ARG_NAME for kw in node.keywords)
        ):
            node.keywords.append(
                ast.keyword(arg=ARG_NAME, value=ast.Name(id=ARG_NAME, ctx=ast.Load()))
            )
    return True


def _patch_text_transformer_forward(func):
    """Add ARG_NAME to CLIPTextTransformer.forward and let it override the embeddings.

    Inserts ``if input_embed is not None: hidden_states = input_embed`` right
    after the first top-level assignment that references ``self.embeddings``.

    Returns:
        True if *func* was modified.
    """
    if not _add_optional_arg(func):
        return False
    override = ast.parse(
        f'if {ARG_NAME} is not None:\n    hidden_states = {ARG_NAME}\n'
    ).body
    # NOTE(review): this only replaces hidden_states. Any earlier check on, or
    # later use of, input_ids in the upstream forward still applies -- confirm
    # against the installed transformers version.
    for index, stmt in enumerate(func.body):
        if isinstance(stmt, ast.Assign) and any(
            _is_self_attr(sub, 'embeddings') for sub in ast.walk(stmt)
        ):
            # Insert, then stop. Continuing to enumerate a list we just grew
            # would walk over the inserted statements.
            func.body[index + 1:index + 1] = override
            break
    return True


# Class name -> patcher applied to that class's ``forward`` method.
_PATCHERS = {
    'CLIPTextModel': _patch_text_model_forward,
    'CLIPTextTransformer': _patch_text_transformer_forward,
}

finder = PathFinder().find_spec('transformers')
# A namespace package has no origin file to anchor the path on.
if finder is None or finder.origin is None:
    raise ImportError('transformers not found')

path = Path(finder.origin).parent / 'models' / 'clip' / 'modeling_clip.py'
# Be explicit about UTF-8. The locale default (e.g. cp1252 on Windows) could
# fail if the file contains non-ASCII characters.
tree = ast.parse(path.read_text(encoding='utf-8'))

modified = False
for node in ast.walk(tree):
    if not isinstance(node, ast.ClassDef) or node.name not in _PATCHERS:
        continue
    forward = _find_method(node, 'forward')
    if forward is not None and _PATCHERS[node.name](forward):
        modified = True

# Only rewrite when something changed. astor regenerates the file from the AST,
# which does not keep comments, so every rewrite loses them and the original
# formatting.
if modified:
    new_source = astor.to_source(tree)
    # Raise (SyntaxError) here rather than leave a broken module installed.
    compile(new_source, str(path), 'exec')
    path.write_text(new_source, encoding='utf-8')