In [None]:
# Few-shot prompt preamble: a role description followed by two worked examples laid
# out as Code / Error / Context / Explanation / Suggested Fix, ending with a lead-in
# line so the material to analyse can be appended directly after it.
# NOTE(review): the examples use TorchScript wording in their error/explanation text
# while their "Context" lines name TorchDynamo - confirm this mix is intentional.
base_prompt = """
You are an expert PyTorch assistant. When given a code snippet, error message, and context, you will explain why the error occurred and suggest a fix.

Example 1:
Code:
def unsupported_op(x):
    print(x)
    return x + 1
Error:
RuntimeError: I/O functions like 'print' are not supported in TorchScript.
Context:
Tracing with TorchDynamo
Explanation:
The 'print' function cannot be traced during graph compilation because it is a Python I/O operation.
Suggested Fix:
Remove the 'print' statement or replace it with a logging mechanism compatible with tracing.

Example 2:
Code:
def dynamic_control_flow(x):
    if x.mean() > 0.5:
        return x * 2
    else:
        return x / 2
Error:
TypeError: cannot infer the value of 'x' because it depends on runtime data.
Context:
Scripting with TorchDynamo
Explanation:
Dynamic control flow is not supported in TorchScript because it cannot be determined statically.
Suggested Fix:
Replace the dynamic condition with a static equivalent using predefined thresholds.

Now, analyze the following:

"""


In [None]:
import sys
import string
import pprint
from itertools import permutations
import json
from googleapiclient.discovery import build

from bs4 import BeautifulSoup
import urllib.parse
import urllib.request
import requests

import spacy
#from spacy_help_functions import get_entities, create_entity_pairs


# Load pre-trained SpanBERT model
#from spanbert import SpanBERT

import os
import google.generativeai as genai
import time

In [None]:
import google.generativeai as genai
import time
# Generate response to prompt
def get_gemini_completion(prompt, model_name, max_tokens, temperature, top_p, top_k):
    """Send ``prompt`` to a Gemini model and return the text of its reply.

    Args:
        prompt: Content passed to ``generate_content``.
        model_name: Identifier passed to ``genai.GenerativeModel``.
        max_tokens: Forwarded to the generation config as ``max_output_tokens``.
        temperature: Forwarded to the generation config as ``temperature``.
        top_p: Forwarded to the generation config as ``top_p``.
        top_k: Forwarded to the generation config as ``top_k``.

    Returns:
        The ``text`` attribute of the response object.
    """
    sampling_options = {
        "max_output_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
    }

    gemini = genai.GenerativeModel(model_name)
    reply = gemini.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(**sampling_options),
    )

    # Fixed pause after every request - presumably to stay under an API rate
    # limit; TODO confirm the intended quota.
    print('waiting...')
    time.sleep(1.8)

    return reply.text

def remove_punctuation(test_str):
    """Return ``test_str`` keeping only its letters, digits and whitespace."""
    kept_chars = [
        ch for ch in test_str
        if ch.isalpha() or ch.isdigit() or ch.isspace()
    ]
    return ''.join(kept_chars)

In [None]:
# Smoke test: configure the Gemini client and request one short completion.
prompt_text = """Give a random Shakespeare quote"""

# Feel free to modify the parameters below.
# Documentation: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
model_name = 'gemini-pro'
max_tokens = 10000
temperature = 0.3
top_p = 1
top_k = 32

# Read the key from the environment so that no secret is ever saved inside the
# notebook (or leaked through version control / shared copies of it).
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    raise RuntimeError(
        "Set the GOOGLE_API_KEY environment variable before running this cell."
    )
genai.configure(api_key=api_key)

response_text = get_gemini_completion(prompt_text, model_name, max_tokens, temperature, top_p, top_k)
print(response_text)

waiting...
"To be or not to be, that is the question." - Hamlet


In [None]:
def graph_break_analysis(model, inputs):
    """Run ``torch._dynamo.explain`` on ``model`` and report its graph breaks.

    The report is printed while it is built and is also returned as one string.

    NOTE(review): a later cell in this notebook defines another function with
    the same name, which replaces this one once that cell runs.

    Args:
        model: Callable handed to ``torch._dynamo.explain``.
        inputs: Mapping of keyword arguments, expanded with ``**`` when the
            explained model is invoked.

    Returns:
        str: The report text. It is returned on every path, including when no
        graph breaks are found, so callers can always concatenate it.
    """
    parts = []

    def emit(text=""):
        # Echo to the notebook and keep the same line for the returned report.
        print(text)
        parts.append(text + "\n")

    emit("\n=== Starting Graph Break Analysis ===")
    explanation = torch._dynamo.explain(model)(**inputs)
    graph_breaks_count = explanation.graph_break_count
    graph_breaks_reasons = explanation.break_reasons

    if graph_breaks_count == 0:
        emit("No graph breaks detected! Your model is fully optimized for torch.compile.")
        # Return the report (not None) so downstream string handling keeps working.
        return "".join(parts)

    emit(f"\nTotal Graph Breaks Detected: {graph_breaks_count}")
    emit("\nGraph Break Reasons:")
    for reason in graph_breaks_reasons:
        for fs in reason.user_stack:
            detail = get_line_from_stack_frame(fs)
            print(detail)
            # The helper's text is stored exactly as returned.
            parts.append(detail)
        emit()
        emit(reason.reason)

    # The closing banner is stored without a trailing newline.
    closing = "\n=== End of Graph Break Analysis ==="
    print(closing)
    parts.append(closing)
    return "".join(parts)


def get_line_from_stack_frame(stack_frame):
    """Describe the source line that a stack frame refers to.

    Args:
        stack_frame: Object exposing ``filename`` and ``lineno`` attributes.

    Returns:
        str: A "Graph break happens at: ..." line (newline-terminated) quoting
        the stripped source text, or an error message when the line number is
        outside the file, the file is missing, or anything else goes wrong
        while reading it.
    """
    file_path, line_number = stack_frame.filename, stack_frame.lineno

    try:
        with open(file_path, 'r') as source:
            source_lines = source.readlines()
            total = len(source_lines)
            # Frame line numbers are 1-based; reject anything outside the file.
            if not (1 <= line_number <= total):
                return f"Error: Line number {line_number} is out of range. The file has {total} lines."
            code_text = source_lines[line_number - 1].strip()
            return f"Graph break happens at: {code_text} , File: {file_path}, in line: {line_number}\n"
    except FileNotFoundError:
        return f"Error: File '{file_path}' not found."
    except Exception as e:
        return f"An error occurred: {e}"



In [None]:
def graph_break_analysis(model, inputs):
    """Run ``torch._dynamo.explain`` on ``model`` and report every field it returns.

    When CUDA is available the model and all tensor-valued inputs are moved to
    the GPU first. The report is printed while it is built and is also returned
    as one string.

    Args:
        model: Object exposing ``.to(device)`` that is handed to
            ``torch._dynamo.explain``.
        inputs: Mapping of keyword arguments, expanded with ``**`` when the
            explained model is invoked. The caller's mapping is never modified.

    Returns:
        str: The full report text (graphs, graph count, graph break count,
        break reasons with source locations, operation count).
    """
    # Check for GPU availability
    if torch.cuda.is_available():
        device = torch.device("cuda")
        model = model.to(device)
        # Build a new mapping instead of overwriting entries of the caller's
        # dict, so the caller keeps its original (un-moved) tensors.
        inputs = {
            key: value.to(device) if torch.is_tensor(value) else value
            for key, value in inputs.items()
        }

    report = []

    def emit(text):
        # Echo to the notebook and keep the same line for the returned report.
        print(text)
        report.append(text + "\n")

    emit("\n=== Starting Graph Break Analysis ===")

    # Call the explain function and store the result
    explanation = torch._dynamo.explain(model)(**inputs)

    # Report each property of the explain output under its own banner.
    emit("\n=== Graphs ===")
    emit(str(explanation.graphs))

    emit("\n=== Graph Count ===")
    emit(str(explanation.graph_count))

    emit("\n=== Graph Break Count ===")
    emit(str(explanation.graph_break_count))

    emit("\n=== Break Reasons ===")
    for reason in explanation.break_reasons:
        for fs in reason.user_stack:
            detail = get_line_from_stack_frame(fs)
            print(detail)
            # The helper's text is stored exactly as returned.
            report.append(detail)
        emit(reason.reason)

    emit("\n=== Operation Count ===")
    emit(str(explanation.op_count))

    # The closing banner is stored without a trailing newline.
    closing = "\n=== End of Graph Break Analysis ==="
    print(closing)
    report.append(closing)
    return "".join(report)


def get_line_from_stack_frame(stack_frame):
    """Describe the source line that a stack frame refers to.

    Every return value ends with a newline. Callers append the text straight
    onto a growing report, so an unterminated error message would run into
    the next entry.

    Args:
        stack_frame: Object exposing ``filename`` and ``lineno`` attributes.

    Returns:
        str: One newline-terminated line: either the stripped source text with
        its file and line number, or an error message when the line number is
        outside the file, the file is missing, or reading it fails.
    """
    file_path = stack_frame.filename
    line_number = stack_frame.lineno

    try:
        with open(file_path, 'r') as file:
            lines = file.readlines()
            # Frame line numbers are 1-based.
            if 1 <= line_number <= len(lines):
                return f"Graph break happens at: {lines[line_number - 1].strip()} , File: {file_path}, in line: {line_number}\n"
            else:
                return f"Error: Line number {line_number} is out of range. The file has {len(lines)} lines.\n"
    except FileNotFoundError:
        return f"Error: File '{file_path}' not found.\n"
    except Exception as e:
        # Best effort: report the problem as text so the analysis can continue.
        return f"An error occurred: {e}\n"


In [None]:
from diffusers import DiffusionPipeline
import torch

# Single prompt shared by the graph-break analysis and the image generation below.
prompt = "apple"

pipe = DiffusionPipeline.from_pretrained("OFA-Sys/small-stable-diffusion-v0")
# pipe = torch.compile(pipe, mode="default", fullgraph=True, backend="cudagraphs")

# Collect the graph-break report for one pipeline call; later cells append it to prompts.
result = graph_break_analysis(pipe, {"prompt": prompt})

# Generate an image for the same prompt and write it to disk.
image = pipe(prompt).images[0]
image.save("image.png")

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

An error occurred while trying to fetch /root/.cache/huggingface/hub/models--OFA-Sys--small-stable-diffusion-v0/snapshots/38e10e5e71e8fbf717a47a81e7543cd01c1a8140/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /root/.cache/huggingface/hub/models--OFA-Sys--small-stable-diffusion-v0/snapshots/38e10e5e71e8fbf717a47a81e7543cd01c1a8140/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
The config attributes {'predict_epsilon': True} were passed to DPMSolverMultistepScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
An error occurred while trying to fetch /root/.cache/huggingface/hub/models--OFA-Sys--small-stable-diffusion-v0/snapshots/38e10e5e71e8fbf717a47a81e7543cd01c1a8140/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /root/.cache/huggingface/hub/models--OFA-Sys--small-stable-diffusion-v0/snapshots/38e10e5e71e8fbf717a


=== Starting Graph Break Analysis ===


  torch._dynamo.utils.warn_once(msg)


  0%|          | 0/50 [00:00<?, ?it/s]

W1213 05:58:19.459000 506 torch/_dynamo/variables/tensor.py:776] [21/0] Graph break from `Tensor.item()`, consider setting:
W1213 05:58:19.459000 506 torch/_dynamo/variables/tensor.py:776] [21/0]     torch._dynamo.config.capture_scalar_outputs = True
W1213 05:58:19.459000 506 torch/_dynamo/variables/tensor.py:776] [21/0] or:
W1213 05:58:19.459000 506 torch/_dynamo/variables/tensor.py:776] [21/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W1213 05:58:19.459000 506 torch/_dynamo/variables/tensor.py:776] [21/0] to include these operations in the captured graph.
W1213 05:58:19.459000 506 torch/_dynamo/variables/tensor.py:776] [21/0] 
W1213 05:58:19.459000 506 torch/_dynamo/variables/tensor.py:776] [21/0] Graph break: from user code at:
W1213 05:58:19.459000 506 torch/_dynamo/variables/tensor.py:776] [21/0]   File "/usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py", line 985, in torch_dynamo_resume_in_index_for_timestep_at_974
W1213 05:58:19.45


=== Graphs ===
[GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule()]

=== Graph Count ===
31

=== Graph Break Count ===
30

=== Break Reasons ===
Graph break happens at: uncond_input = self.tokenizer( , File: /usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py, in line: 469

Graph break happens at: encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs) , File: /usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py, in line: 3021

Graph break happens at: return self.batch_encode_plus( ,

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Build a prompt that asks the model for a JSON-formatted explanation of each
# graph break, then append the analysis report (`result`, produced by the
# graph_break_analysis call in an earlier cell) as the material to analyse.
prompt_text = """
You are a PyTorch expert assisting in debugging graph breaks in PyTorch's `torch.compile` pipeline. Given the detailed analysis of graph breaks below, explain why each graph break occurred and suggest actionable fixes for them. Ensure the explanations are clear, and the suggestions are practical and compatible with PyTorch 2.x.

### Instructions:
1. Analyze the provided graph break analysis report.
2. For each graph break, provide:
   - An explanation of the root cause in simple terms.
   - A fix or workaround to resolve the graph break.
   - Indicate whether the fix is compatible with `torch.compile`.

### Output Format:
Provide the output in JSON format like this:
{
    "graph_breaks": [
        {
            "location": "Line X, File: /path/to/file.py",
            "cause": "Brief explanation of the cause.",
            "fix": "Detailed fix or workaround.",
            "compatible_with_torch_compile": true/false
        },
        ...
    ]
}

### Graph Break Analysis:
"""
# The report goes last so it sits directly under the "Graph Break Analysis" heading.
prompt_text += result

# Call the Gemini API, reusing the model name and sampling settings from the
# earlier configuration cell; the reply is kept in `response_text`.
response_text = get_gemini_completion(prompt_text, model_name, max_tokens, temperature, top_p, top_k)

waiting...


In [None]:
# Alternative, free-form prompt: ask for a plain-language explanation of the same
# graph-break report. Running this cell overwrites both `prompt_text` and
# `response_text` from the JSON-format prompt cell.
prompt_text = """Here is the output of a function that tracks graph break data from using the PyTorch compiler. Please explain the output in understandable terms. For each graph break occurrence, explain what happened, why the graph break happened, and a possible action to address it: """
# Disabled alternative: start from the few-shot `base_prompt` instead.
#prompt_text = base_prompt
prompt_text += result
response_text = get_gemini_completion(prompt_text, model_name, max_tokens, temperature, top_p, top_k)

waiting...


In [None]:
# Show the exact prompt that was sent; print() renders the embedded newlines,
# whereas a bare expression would display the escaped repr.
print(prompt_text)

Here is the output of a function that tracks graph break data from using the PyTorch compiler. Please explain the output in understandable terms. For each graph break occurrence, explain what happened, why the graph break happened, and a possible action to address it: 
=== Starting Graph Break Analysis ===

=== Graphs ===
[GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule()]

=== Graph Count ===
31

=== Graph Break Count ===
30

=== Break Reasons ===
Graph break happens at: uncond_input = self.tokenizer( , File: /usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_dif

In [None]:
# Show the model's reply; print() renders the embedded newlines, whereas a
# bare expression would display the escaped repr.
print(response_text)

**Graph Break 1:**
* **What happened:** The graph break occurs when the tokenizer is called without any arguments.
* **Why it happened:** The tokenizer expects at least one argument, which is the text to be tokenized.
* **Possible action:** Pass the appropriate text to the tokenizer.

**Graph Break 2:**
* **What happened:** The graph break occurs when the `_call_one` method of the tokenizer is called with the `text_pair` argument set to `None`.
* **Why it happened:** The `_call_one` method expects the `text_pair` argument to be a string or a list of strings.
* **Possible action:** Pass a valid value for the `text_pair` argument.

**Graph Break 3:**
* **What happened:** The graph break occurs when the `batch_encode_plus` method of the tokenizer is called with the `add_special_tokens` argument set to `False`.
* **Why it happened:** The `batch_encode_plus` method expects the `add_special_tokens` argument to be a boolean value.
* **Possible action:** Pass a valid value for the `add_special