In [2]:
!pip install triton

Collecting triton
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.1.0


In [3]:
# Few-shot prompt template: two worked examples (Code / Error / Context / Explanation /
# Suggested Fix) followed by an open "Now, analyze the following:" slot for a new case.
# NOTE(review): the only visible use of this template is commented out
# (`#prompt_text = base_prompt`), so it currently appears unused.
# NOTE(review): the examples mix TorchScript wording ("not supported in TorchScript") with
# TorchDynamo contexts; confirm they describe torch.compile behaviour before relying on
# them as few-shot guidance.
base_prompt = """
You are an expert PyTorch assistant. When given a code snippet, error message, and context, you will explain why the error occurred and suggest a fix.

Example 1:
Code:
def unsupported_op(x):
    print(x)
    return x + 1
Error:
RuntimeError: I/O functions like 'print' are not supported in TorchScript.
Context:
Tracing with TorchDynamo
Explanation:
The 'print' function cannot be traced during graph compilation because it is a Python I/O operation.
Suggested Fix:
Remove the 'print' statement or replace it with a logging mechanism compatible with tracing.

Example 2:
Code:
def dynamic_control_flow(x):
    if x.mean() > 0.5:
        return x * 2
    else:
        return x / 2
Error:
TypeError: cannot infer the value of 'x' because it depends on runtime data.
Context:
Scripting with TorchDynamo
Explanation:
Dynamic control flow is not supported in TorchScript because it cannot be determined statically.
Suggested Fix:
Replace the dynamic condition with a static equivalent using predefined thresholds.

Now, analyze the following:

"""


In [4]:
import sys
import string
import pprint
from itertools import permutations
import json
from googleapiclient.discovery import build

from bs4 import BeautifulSoup
import urllib.parse
import urllib.request
import requests

import spacy
#from spacy_help_functions import get_entities, create_entity_pairs


# Load pre-trained SpanBERT model
#from spanbert import SpanBERT

import os
import google.generativeai as genai
import time

In [5]:
# Generate response to prompt
def get_gemini_completion(prompt, model_name, max_tokens, temperature, top_p, top_k, rate_limit_delay=1.8):
    """Generate a response to ``prompt`` with a Gemini model and return its text.

    Args:
        prompt: the prompt text to send.
        model_name: Gemini model identifier, passed to ``genai.GenerativeModel``.
        max_tokens: maximum number of output tokens (``max_output_tokens``).
        temperature: sampling temperature.
        top_p: nucleus-sampling probability mass.
        top_k: top-k sampling cutoff.
        rate_limit_delay: seconds to pause after each request (default 1.8, the value
            this helper previously hard-coded). Presumably a throttle to stay within
            API request-rate limits; pass 0 to skip the pause.

    Returns:
        str: ``response.text`` of the generated response.
    """
    # Initialize a generative model
    model = genai.GenerativeModel(model_name)

    # Configure the model with the requested sampling parameters
    generation_config = genai.types.GenerationConfig(
        max_output_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )

    # Generate a response
    response = model.generate_content(prompt, generation_config=generation_config)

    # Pause between consecutive requests. A non-positive delay disables the pause
    # (time.sleep would raise ValueError for a negative value).
    if rate_limit_delay > 0:
        print('waiting...')
        time.sleep(rate_limit_delay)

    return response.text

def remove_punctuation(test_str):
    """Return ``test_str`` keeping only letters, digits and whitespace.

    Every other character (punctuation, symbols, underscores, ...) is dropped.
    """
    # isalpha/isdigit/isspace are used rather than isalnum so the kept set is exactly
    # "letter or digit or whitespace".
    kept_chars = [ch for ch in test_str if ch.isalpha() or ch.isdigit() or ch.isspace()]
    return "".join(kept_chars)

In [7]:
import getpass

prompt_text = """Give a random Shakespeare quote"""

# Feel free to modify the parameters below.
# Documentation: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
model_name = 'gemini-pro'
max_tokens = 10000
temperature = 0.3
top_p = 1
top_k = 32

# Never hard-code the API key in a notebook cell; cells and their outputs are easily
# shared or committed. Read it from the environment, or prompt for it without echoing.
api_key = os.environ.get("GOOGLE_API_KEY") or getpass.getpass("Gemini API key: ")
genai.configure(api_key=api_key)

response_text = get_gemini_completion(prompt_text, model_name, max_tokens, temperature, top_p, top_k)
print(response_text)

waiting...
"To be or not to be, that is the question." - Hamlet


In [8]:
def graph_break_analysis(model, inputs):
    """Summarise TorchDynamo graph breaks for ``model`` called with ``inputs``.

    Prints a human-readable report and also accumulates the same text into a string.

    Args:
        model: the callable to analyse (e.g. an ``nn.Module`` or ``model.generate``).
        inputs: arguments for ``model``. A mapping (dict, tokenizer output) is passed as
            keyword arguments; a tuple or list is passed positionally; any other object
            (e.g. a single tensor) is passed as the only positional argument.

    Returns:
        str: the accumulated report. A string is returned on every path, including when
        no graph breaks are found, so callers can always do ``prompt_text += result``.
    """
    import collections.abc

    result = ""
    print("\n=== Starting Graph Break Analysis ===")
    result += "\n=== Starting Graph Break Analysis ===\n"

    # A bare `**inputs` raises TypeError for non-mapping inputs such as a single tensor,
    # so dispatch on the kind of input.
    explain_fn = torch._dynamo.explain(model)
    if isinstance(inputs, collections.abc.Mapping):
        explanation = explain_fn(**inputs)
    elif isinstance(inputs, (tuple, list)):
        explanation = explain_fn(*inputs)
    else:
        explanation = explain_fn(inputs)

    graph_breaks_count = explanation.graph_break_count
    graph_breaks_reasons = explanation.break_reasons

    if graph_breaks_count == 0:
        print("No graph breaks detected! Your model is fully optimized for torch.compile.")
        result += "No graph breaks detected! Your model is fully optimized for torch.compile.\n"
        return result

    print(f"\nTotal Graph Breaks Detected: {graph_breaks_count}")
    result += f"\nTotal Graph Breaks Detected: {graph_breaks_count}\n"

    print("\nGraph Break Reasons:")
    result += "\nGraph Break Reasons:\n"
    for reason in graph_breaks_reasons:
        # One line per user stack frame leading to the break, then the reason itself.
        for fs in reason.user_stack:
            detail = get_line_from_stack_frame(fs)
            print(detail)
            # Keep one report line per frame even if the helper's text lacks a newline.
            result += detail if detail.endswith("\n") else detail + "\n"
        print()
        result += "\n"
        print(reason.reason)
        result += reason.reason + "\n"

    print("\n=== End of Graph Break Analysis ===")
    result += "\n=== End of Graph Break Analysis ==="
    return result


def get_line_from_stack_frame(stack_frame):
    """Describe the source line a stack frame points at, for the graph-break report.

    Args:
        stack_frame: an object with ``filename`` and ``lineno`` attributes (e.g. an
            entry of ``reason.user_stack``).

    Returns:
        str: a newline-terminated, one-line description. Problems (missing file, line
        out of range, anything else) are returned as text rather than raised, so one
        unreadable frame does not abort the whole report.
    """
    import linecache

    file_path = stack_frame.filename
    line_number = stack_frame.lineno

    try:
        # linecache also serves sources that are not files on disk (IPython stores
        # notebook-cell source there) and caches files, so several frames from the same
        # file do not re-read it. checkcache drops entries whose file changed on disk.
        linecache.checkcache(file_path)
        lines = linecache.getlines(file_path)
        if not lines:
            # Not available through linecache (or empty): read directly so a missing
            # file still surfaces as FileNotFoundError below.
            with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
                lines = file.readlines()
        # Line numbers are 1-indexed
        if 1 <= line_number <= len(lines):
            return f"Graph break happens at: {lines[line_number - 1].strip()} , File: {file_path}, in line: {line_number}\n"
        return f"Error: Line number {line_number} is out of range. The file has {len(lines)} lines.\n"
    except FileNotFoundError:
        return f"Error: File '{file_path}' not found.\n"
    except Exception as e:
        return f"An error occurred: {e}\n"



In [9]:
def graph_break_analysis(model, inputs):
    """Run ``torch._dynamo.explain`` on ``model`` and report graphs and graph breaks.

    Every section is printed and also accumulated into a plain-text report, which is
    returned so it can be appended to an LLM prompt.

    Args:
        model: the callable to analyse (e.g. an ``nn.Module`` or ``model.generate``).
        inputs: arguments for ``model``. A mapping (dict, tokenizer/processor output) is
            passed as keyword arguments; a tuple or list is passed positionally; any
            other object (e.g. a single tensor) is passed as the only positional argument.

    Returns:
        str: the accumulated report text.
    """
    import collections.abc

    result = ""
    print("\n=== Starting Graph Break Analysis ===")
    result += "\n=== Starting Graph Break Analysis ===\n"

    # A bare `**inputs` raises TypeError for non-mapping inputs such as `torch.randn(1)`,
    # so dispatch on the kind of input.
    explain_fn = torch._dynamo.explain(model)
    if isinstance(inputs, collections.abc.Mapping):
        explanation = explain_fn(**inputs)
    elif isinstance(inputs, (tuple, list)):
        explanation = explain_fn(*inputs)
    else:
        explanation = explain_fn(inputs)

    # Access and print each property of the ExplainOutput object
    print("\n=== Graphs ===")
    result += "\n=== Graphs ===\n"
    result += str(explanation.graphs) + "\n"
    print(explanation.graphs)

    print("\n=== Graph Count ===")
    result += "\n=== Graph Count ===\n"
    result += str(explanation.graph_count) + "\n"
    print(explanation.graph_count)

    print("\n=== Graph Break Count ===")
    result += "\n=== Graph Break Count ===\n"
    result += str(explanation.graph_break_count) + "\n"
    print(explanation.graph_break_count)

    print("\n=== Break Reasons ===")
    result += "\n=== Break Reasons ===\n"
    for reason in explanation.break_reasons:
        # One line per user stack frame leading to the break, then the reason itself.
        for fs in reason.user_stack:
            detail = get_line_from_stack_frame(fs)
            print(detail)
            # Keep one report line per frame even if the helper's text lacks a newline.
            result += detail if detail.endswith("\n") else detail + "\n"
        print(reason.reason)
        result += reason.reason + "\n"

    print("\n=== Operation Count ===")
    result += "\n=== Operation Count ===\n"
    result += str(explanation.op_count) + "\n"
    print(explanation.op_count)

    print("\n=== End of Graph Break Analysis ===")
    result += "\n=== End of Graph Break Analysis ==="
    return result


def get_line_from_stack_frame(stack_frame):
    """Describe the source line a stack frame points at, for the graph-break report.

    Args:
        stack_frame: an object with ``filename`` and ``lineno`` attributes (e.g. an
            entry of ``reason.user_stack``).

    Returns:
        str: a newline-terminated, one-line description. Problems (missing file, line
        out of range, anything else) are returned as text rather than raised, so one
        unreadable frame does not abort the whole report.
    """
    import linecache

    file_path = stack_frame.filename
    line_number = stack_frame.lineno

    try:
        # linecache also serves sources that are not files on disk (IPython stores
        # notebook-cell source there) and caches files, so several frames from the same
        # file do not re-read it. checkcache drops entries whose file changed on disk.
        linecache.checkcache(file_path)
        lines = linecache.getlines(file_path)
        if not lines:
            # Not available through linecache (or empty): read directly so a missing
            # file still surfaces as FileNotFoundError below.
            with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
                lines = file.readlines()
        # Line numbers are 1-indexed
        if 1 <= line_number <= len(lines):
            return f"Graph break happens at: {lines[line_number - 1].strip()} , File: {file_path}, in line: {line_number}\n"
        return f"Error: Line number {line_number} is out of range. The file has {len(lines)} lines.\n"
    except FileNotFoundError:
        return f"Error: File '{file_path}' not found.\n"
    except Exception as e:
        return f"An error occurred: {e}\n"


In [10]:
import torch
import triton
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the small T5 checkpoint and its tokenizer from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
# Replace generate with a torch.compile-wrapped version.
# NOTE(review): graph_break_analysis wraps this again with torch._dynamo.explain; confirm
# whether the torch.compile wrap is needed for the analysis or only for later inference.
model.generate = torch.compile(model.generate)


# Input text uses the "summarize: " task prefix.
input_text = "summarize: The quick brown fox jumps over the lazy dog. The dog barked loudly at the fox."
# Tokenize to PyTorch tensors; graph_break_analysis unpacks this mapping as keyword arguments.
inputs = tokenizer(input_text, return_tensors="pt")


######## Graph Break Detect #######
# Build the graph-break report for generate(); `result` feeds the Gemini prompt cells below.
result = graph_break_analysis(model.generate, inputs)
# Clear TorchDynamo's caches so later compilations/analyses start from a clean state.
torch._dynamo.reset()

# with torch.no_grad():
#     outputs = model.generate(inputs.input_ids)

# decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
# print("Model Output:", decoded_output)

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]


=== Starting Graph Break Analysis ===

=== Graphs ===
[GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule()]

=== Graph Count ===
14

=== Graph Break Count ===
13

=== Break Reasons ===
Graph break happens at: result = self._sample( , File: /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py, in line: 2215

Graph break happens at: model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) , File: /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py, in line: 3199

Graph break happens at: if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3 , File: /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py, in line: 384

Dynamic control flow is not supported at the moment. Please use functorch.experimental.cont

In [11]:
# Ask Gemini to explain each graph break in `result` (the report from graph_break_analysis)
# and to answer in a fixed JSON schema.
# NOTE(review): this exact template is pasted again in a later cell; consider defining it once.
prompt_text = """
You are a PyTorch expert assisting in debugging graph breaks in PyTorch's `torch.compile` pipeline. Given the detailed analysis of graph breaks below, explain why each graph break occurred and suggest actionable fixes for them. Ensure the explanations are clear, and the suggestions are practical and compatible with PyTorch 2.x.

### Instructions:
1. Analyze the provided graph break analysis report.
2. For each graph break, provide:
   - An explanation of the root cause in simple terms.
   - A fix or workaround to resolve the graph break.
   - Indicate whether the fix is compatible with `torch.compile`.

### Output Format:
Provide the output in JSON format like this:
{
    "graph_breaks": [
        {
            "location": "Line X, File: /path/to/file.py",
            "cause": "Brief explanation of the cause.",
            "fix": "Detailed fix or workaround.",
            "compatible_with_torch_compile": true/false
        },
        ...
    ]
}

### Graph Break Analysis:
"""
# Append the report text. Assumes `result` is a str; the dict-returning
# graph_break_analysis variant in the last cell would make this raise TypeError.
prompt_text += result

# Call the Gemini API
# NOTE(review): the reply comes back wrapped in a ```json fence, so strip it before json.loads.
response_text = get_gemini_completion(prompt_text, model_name, max_tokens, temperature, top_p, top_k)

waiting...


In [None]:
# Alternative free-form prompt: ask Gemini for a plain-language explanation of each graph
# break in the `result` report.
prompt_text = """Here is the output of a function that tracks graph break data from using the PyTorch compiler. Please explain the output in understandable terms. For each graph break occurrence, explain what happened, why the graph break happened, and a possible action to address it: """
# Uncomment to use the few-shot `base_prompt` template instead of the text above.
#prompt_text = base_prompt
prompt_text += result
response_text = get_gemini_completion(prompt_text, model_name, max_tokens, temperature, top_p, top_k)

In [12]:
# Show the full prompt (template + graph-break report) assembled above.
print(prompt_text)


You are a PyTorch expert assisting in debugging graph breaks in PyTorch's `torch.compile` pipeline. Given the detailed analysis of graph breaks below, explain why each graph break occurred and suggest actionable fixes for them. Ensure the explanations are clear, and the suggestions are practical and compatible with PyTorch 2.x.

### Instructions:
1. Analyze the provided graph break analysis report.
2. For each graph break, provide:
   - An explanation of the root cause in simple terms.
   - A fix or workaround to resolve the graph break.
   - Indicate whether the fix is compatible with `torch.compile`.

### Output Format:
Provide the output in JSON format like this:
{
    "graph_breaks": [
        {
            "location": "Line X, File: /path/to/file.py",
            "cause": "Brief explanation of the cause.",
            "fix": "Detailed fix or workaround.",
            "compatible_with_torch_compile": true/false
        },
        ...
    ]
}

### Graph Break Analysis:

=== Startin

In [13]:
# Show Gemini's reply (requested as JSON; it may arrive wrapped in a ```json fence).
print(response_text)

```json
{
    "graph_breaks": [
        {
            "location": "Line 2215, File: /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py",
            "cause": "Dynamic control flow is not supported in the `torch.compile` pipeline.",
            "fix": "Use `functorch.experimental.control_flow.cond` to explicitly capture the control flow.",
            "compatible_with_torch_compile": false
        },
        {
            "location": "Line 3199, File: /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py",
            "cause": "Dynamic control flow is not supported in the `torch.compile` pipeline.",
            "fix": "Use `functorch.experimental.control_flow.cond` to explicitly capture the control flow.",
            "compatible_with_torch_compile": false
        },
        {
            "location": "Line 384, File: /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py",
            "cause": "Dynamic control flow is not suppor

In [14]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

# Use the GPU when available; both the model and its inputs are placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load processor and model, moving the model to `device`
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").to(device)

# Wrap generate with torch.compile (the model is already on `device`).
model.generate = torch.compile(model.generate)


# Prepare inputs on the same device as the model
inputs = processor(
    text="80s pop track with bassy drums and synth",
    return_tensors="pt"
).to(device)

# Run inference
# with torch.no_grad():
#     outputs = model.generate(**inputs, max_new_tokens=256)

# Build the graph-break report for generate(), then clear TorchDynamo's caches so the
# analysis leaves no compiled state behind, as the other analysis cells do.
result = graph_break_analysis(model.generate, inputs)
torch._dynamo.reset()


preprocessor_config.json:   0%|          | 0.00/275 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.37k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/7.87k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.36G [00:00<?, ?B/s]

  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)
Config of the text_encoder: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> is overwritten by shared text_encoder config: T5Config {
  "_name_or_path": "t5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "leng

generation_config.json:   0%|          | 0.00/224 [00:00<?, ?B/s]


=== Starting Graph Break Analysis ===


W1211 22:39:03.923000 316 torch/_dynamo/convert_frame.py:844] [36/8] torch._dynamo hit config.cache_size_limit (8)
W1211 22:39:03.923000 316 torch/_dynamo/convert_frame.py:844] [36/8]    function: 'torch_dynamo_resume_in_forward_at_167' (/usr/local/lib/python3.10/dist-packages/transformers/models/encodec/modeling_encodec.py:167)
W1211 22:39:03.923000 316 torch/_dynamo/convert_frame.py:844] [36/8]    last reason: 36/0: ___check_type_id(L['self']._modules['conv'], 94314477137696)
W1211 22:39:03.923000 316 torch/_dynamo/convert_frame.py:844] [36/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1211 22:39:03.923000 316 torch/_dynamo/convert_frame.py:844] [36/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.



=== Graphs ===
[GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule(), GraphModule()]

=== Graph Count ===
51

=== Graph Break Count ===
50

=== Break Reasons ===
Graph break happens at: input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( , File: /usr/local/lib/python3.10/dist

In [15]:
# Same structured JSON prompt as the T5 cell, now applied to the MusicGen `result` report.
# NOTE(review): duplicated template; consider a single shared definition.
prompt_text = """
You are a PyTorch expert assisting in debugging graph breaks in PyTorch's `torch.compile` pipeline. Given the detailed analysis of graph breaks below, explain why each graph break occurred and suggest actionable fixes for them. Ensure the explanations are clear, and the suggestions are practical and compatible with PyTorch 2.x.

### Instructions:
1. Analyze the provided graph break analysis report.
2. For each graph break, provide:
   - An explanation of the root cause in simple terms.
   - A fix or workaround to resolve the graph break.
   - Indicate whether the fix is compatible with `torch.compile`.

### Output Format:
Provide the output in JSON format like this:
{
    "graph_breaks": [
        {
            "location": "Line X, File: /path/to/file.py",
            "cause": "Brief explanation of the cause.",
            "fix": "Detailed fix or workaround.",
            "compatible_with_torch_compile": true/false
        },
        ...
    ]
}

### Graph Break Analysis:
"""
# Append the report text (assumes `result` is the str report, not the dict variant).
prompt_text += result

# Call the Gemini API
response_text = get_gemini_completion(prompt_text, model_name, max_tokens, temperature, top_p, top_k)

waiting...


In [16]:
# Show the prompt built from the MusicGen report, then Gemini's reply to it.
print(prompt_text)
print(response_text)


You are a PyTorch expert assisting in debugging graph breaks in PyTorch's `torch.compile` pipeline. Given the detailed analysis of graph breaks below, explain why each graph break occurred and suggest actionable fixes for them. Ensure the explanations are clear, and the suggestions are practical and compatible with PyTorch 2.x.

### Instructions:
1. Analyze the provided graph break analysis report.
2. For each graph break, provide:
   - An explanation of the root cause in simple terms.
   - A fix or workaround to resolve the graph break.
   - Indicate whether the fix is compatible with `torch.compile`.

### Output Format:
Provide the output in JSON format like this:
{
    "graph_breaks": [
        {
            "location": "Line X, File: /path/to/file.py",
            "cause": "Brief explanation of the cause.",
            "fix": "Detailed fix or workaround.",
            "compatible_with_torch_compile": true/false
        },
        ...
    ]
}

### Graph Break Analysis:

=== Startin

In [None]:
# Free-form variant of the prompt for the current `result` report; prints both the prompt
# and Gemini's answer.
prompt_text = """Here is the output of a function that tracks graph break data from using the PyTorch compiler. Please explain the output in understandable terms. For each graph break occurrence, explain what happened, why the graph break happened, and a possible action to address it: """
prompt_text += result
print(prompt_text)
response_text = get_gemini_completion(prompt_text, model_name, max_tokens, temperature, top_p, top_k)
print(response_text)

In [None]:
import torch

def my_function(x):
    """Toy function that branches on the value of ``x``.

    Returns ``x * 2`` when ``x > 5`` and ``x / 2`` otherwise. For a tensor input the
    branch depends on runtime data, presumably intended as a graph-break example.
    """
    return x * 2 if x > 5 else x / 2

# Analyse my_function for graph breaks. Every graph_break_analysis version in this notebook
# accepts a keyword-argument mapping, whereas a bare tensor hits `**inputs` and raises
# TypeError ("argument after ** must be a mapping"), so pass the tensor as {"x": ...}.
graph_break_analysis(my_function, {"x": torch.randn(1)})

# Run the function eagerly on a fresh random input and show the result
output = my_function(torch.randn(1))
print(output)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForImageClassification

# Hugging Face checkpoints to analyse. Each entry gives a display name, a model family
# ("seq2seq", "causal", "masked", "image-classification") that selects the Auto* class and
# the dummy inputs below, a tokenizer checkpoint (None for the image model) and the model
# checkpoint.
models_to_test = [
    dict(name="t5-small", type="seq2seq", tokenizer="google-t5/t5-small", model="google-t5/t5-small"),
    dict(name="gpt2", type="causal", tokenizer="gpt2", model="gpt2"),
    dict(name="bert-base-uncased", type="masked", tokenizer="bert-base-uncased", model="bert-base-uncased"),
    dict(
        name="google/vit-base-patch16-224",
        type="image-classification",
        tokenizer=None,
        model="google/vit-base-patch16-224",
    ),
]

# Graph Break Analysis
def graph_break_analysis(model, inputs):
    """Run ``torch._dynamo.explain`` on ``model`` and collect the results in a dict.

    This version never raises: any exception from the analysis is stored under
    ``"error"`` so one failing model does not stop a loop over many models.

    Args:
        model: the callable to analyse (an ``nn.Module`` or a bound method such as
            ``model.generate``).
        inputs: arguments for ``model``. A mapping (dict, tokenizer output) is passed as
            keyword arguments; a tuple or list is passed positionally; any other object
            (e.g. a single tensor) is passed as the only positional argument.

    Returns:
        dict: ``graphs`` (str repr of the captured graphs), ``graph_count``,
        ``graph_break_count``, ``op_count``, ``break_reasons`` (one
        ``{"detail", "reason"}`` entry per user stack frame of each break) and, on
        failure, ``error``.
    """
    import collections.abc

    result = {"graphs": None, "graph_count": None, "graph_break_count": None, "break_reasons": [], "op_count": None}
    try:
        # Dispatch on the kind of input; a bare `**inputs` only works for mappings.
        explain_fn = torch._dynamo.explain(model)
        if isinstance(inputs, collections.abc.Mapping):
            explanation = explain_fn(**inputs)
        elif isinstance(inputs, (tuple, list)):
            explanation = explain_fn(*inputs)
        else:
            explanation = explain_fn(inputs)

        # Store results
        result["graphs"] = str(explanation.graphs)
        result["graph_count"] = explanation.graph_count
        result["graph_break_count"] = explanation.graph_break_count
        result["op_count"] = explanation.op_count

        for reason in explanation.break_reasons:
            for fs in reason.user_stack:
                detail = get_line_from_stack_frame(fs)
                result["break_reasons"].append({"detail": detail, "reason": reason.reason})
    except Exception as e:
        # Deliberate best-effort: record the failure instead of aborting the caller's loop.
        result["error"] = str(e)

    return result

def get_line_from_stack_frame(stack_frame):
    """Describe the source line a stack frame points at (no trailing newline).

    Args:
        stack_frame: an object with ``filename`` and ``lineno`` attributes (e.g. an
            entry of ``reason.user_stack``).

    Returns:
        str: a one-line description. Problems (missing file, line out of range,
        anything else) are returned as text rather than raised.
    """
    import linecache

    file_path = stack_frame.filename
    line_number = stack_frame.lineno

    try:
        # linecache also serves sources that are not files on disk (IPython stores
        # notebook-cell source there) and caches files, so several frames from the same
        # file do not re-read it. checkcache drops entries whose file changed on disk.
        linecache.checkcache(file_path)
        lines = linecache.getlines(file_path)
        if not lines:
            # Not available through linecache (or empty): read directly so a missing
            # file still surfaces as FileNotFoundError below.
            with open(file_path, "r", encoding="utf-8", errors="replace") as file:
                lines = file.readlines()
        # Line numbers are 1-indexed
        if 1 <= line_number <= len(lines):
            return f"Graph break happens at: {lines[line_number - 1].strip()} , File: {file_path}, in line: {line_number}"
        return f"Error: Line number {line_number} is out of range. The file has {len(lines)} lines."
    except FileNotFoundError:
        return f"Error: File '{file_path}' not found."
    except Exception as e:
        return f"An error occurred: {e}"

# Run analysis for each model
# Loop variables that would otherwise overwrite notebook-level names get distinct names:
# `model_name` is the Gemini model id used by get_gemini_completion, and `result` is the
# report string the prompt cells append to. Overwriting either breaks re-running those cells.
for model_info in models_to_test:
    hf_model_name = model_info["name"]
    model_type = model_info["type"]
    print(f"\n=== Analyzing {hf_model_name} ===")

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_info["tokenizer"]) if model_info["tokenizer"] else None
    if model_type == "seq2seq":
        model = AutoModelForSeq2SeqLM.from_pretrained(model_info["model"])
    elif model_type == "causal":
        model = AutoModelForCausalLM.from_pretrained(model_info["model"])
    elif model_type == "masked":
        model = AutoModelForMaskedLM.from_pretrained(model_info["model"])
    elif model_type == "image-classification":
        model = AutoModelForImageClassification.from_pretrained(model_info["model"])
    else:
        # Fail loudly instead of silently reusing the previous iteration's model.
        raise ValueError(f"Unsupported model type {model_type!r} for {hf_model_name}")

    # Prepare inputs
    if model_type in ["seq2seq", "causal", "masked"]:
        input_text = "The quick brown fox jumps over the lazy dog."
        inputs = tokenizer(input_text, return_tensors="pt")
    else:  # image-classification
        # Use a dummy input tensor for image models
        inputs = {"pixel_values": torch.rand(1, 3, 224, 224)}

    # Only the autoregressive families (seq2seq, causal) are analysed through .generate().
    # The masked LM and the image classifier are analysed through their forward pass.
    # NOTE(review): for ViT, .generate() is expected to raise (no language-model head) or,
    # in newer transformers releases, not to exist at all.
    analysis_target = model.generate if model_type in ("seq2seq", "causal") else model
    analysis_result = graph_break_analysis(analysis_target, inputs)
    torch._dynamo.reset()

    # Print results
    print(f"Results for {hf_model_name}:")
    for key, value in analysis_result.items():
        if key == "break_reasons":
            print(f"  {key}:")
            for reason in value:
                print(f"    {reason['detail']}")
                print(f"    Reason: {reason['reason']}")
        else:
            print(f"  {key}: {value}")
    print("\n")
